fixed open issues

This commit is contained in:
ottofabian 2021-05-17 17:58:33 +02:00
parent b39104a449
commit 14c60766c2
4 changed files with 56 additions and 73 deletions

View File

@ -214,16 +214,17 @@ register(
)
# MP environments
reacher_envs = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
for env in reacher_envs:
name = env.split("-")
## Simple Reacher
versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
for v in versions:
name = v.split("-")
register(
id=f'{name[0]}DMP-{name[1]}',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
# max_episode_steps=1,
kwargs={
"name": f"alr_envs:{env}",
"num_dof": 2 if "long" not in env.lower() else 5 ,
"name": f"alr_envs:{v}",
"num_dof": 2 if "long" not in v.lower() else 5,
"num_basis": 5,
"duration": 2,
"alpha_phase": 2,
@ -249,59 +250,33 @@ register(
}
)
register(
id='HoleReacherDMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
# max_episode_steps=1,
kwargs={
"name": "alr_envs:HoleReacher-v0",
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"learn_goal": True,
"alpha_phase": 2,
"bandwidth_factor": 2,
"policy_type": "velocity",
"weights_scale": 50,
"goal_scale": 0.1
}
)
## Hole Reacher
versions = ["v0", "v1", "v2"]
for v in versions:
register(
id=f'HoleReacherDMP-{v}',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
# max_episode_steps=1,
kwargs={
"name": f"alr_envs:HoleReacher-{v}",
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"learn_goal": True,
"alpha_phase": 2,
"bandwidth_factor": 2,
"policy_type": "velocity",
"weights_scale": 50,
"goal_scale": 0.1
}
)
register(
id='HoleReacherDMP-v1',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
# max_episode_steps=1,
kwargs={
"name": "alr_envs:HoleReacher-v1",
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"learn_goal": True,
"alpha_phase": 2,
"bandwidth_factor": 2,
"policy_type": "velocity",
"weights_scale": 50,
"goal_scale": 0.1
}
)
register(
id='HoleReacherDMP-v2',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
# max_episode_steps=1,
kwargs={
"name": "alr_envs:HoleReacher-v2",
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"learn_goal": True,
"alpha_phase": 2,
"bandwidth_factor": 2,
"policy_type": "velocity",
"weights_scale": 50,
"goal_scale": 0.1
}
)
# register(
# id='HoleReacherDetPMP-v0',
# entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp',
# # max_episode_steps=1,
# # TODO: add mp kwargs
# )
# TODO: properly add final_pos
register(
@ -321,12 +296,7 @@ register(
}
)
# register(
# id='HoleReacherDetPMP-v0',
# entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp',
# # max_episode_steps=1,
# # TODO: add mp kwargs
# )
## Ball in Cup
register(
id='ALRBallInACupSimpleDMP-v0',

View File

@ -127,12 +127,6 @@ class SimpleReacherEnv(MPEnv):
def _generate_goal(self):
if self._target is None:
# center = self._joints[0]
# # Sample uniformly in circle with radius R around center of reacher.
# R = np.sum(self.link_lengths)
# r = R * np.sqrt(self.np_random.uniform())
# theta = self.np_random.uniform() * 2 * np.pi
# goal = center + r * np.stack([np.cos(theta), np.sin(theta)])
total_length = np.sum(self.link_lengths)
goal = np.array([total_length, total_length])

View File

@ -99,9 +99,28 @@ class ViaPointReacher(MPEnv):
return self._get_obs().copy()
def _generate_goal(self):
self._via_point = self.np_random.uniform(0.5, 3.5, 2) if self._via_target is None else np.copy(self._via_target)
self._goal = self.np_random.uniform(0.5, 0.1, 2) if self._target is None else np.copy(self._target)
# raise NotImplementedError("How to properly sample points??")
# TODO: Maybe improve this later, this can yield quite a lot of invalid settings
total_length = np.sum(self.link_lengths)
# rejection sampled point in inner circle with 0.5*Radius
if self._via_target is None:
via_target = np.array([total_length, total_length])
while np.linalg.norm(via_target) >= 0.5 * total_length:
via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2)
else:
via_target = np.copy(self._via_target)
# rejection sampled point in outer circle
if self._target is None:
goal = np.array([total_length, total_length])
while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length:
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
else:
goal = np.copy(self._target)
self._via_target = via_target
self._goal = goal
def _update_joints(self):
"""

View File

@ -70,7 +70,7 @@ class AlrMpEnvSampler:
class AlrContextualMpEnvSampler:
"""
An asynchronous sampler for non contextual MPWrapper environments. A sampler object can be called with a set of
An asynchronous sampler for contextual MPWrapper environments. A sampler object can be called with a set of
parameters and returns the corresponding final obs, rewards, dones and info dicts.
"""
def __init__(self, env_id, num_envs, seed=0, **env_kwargs):