fixed open issues
This commit is contained in:
parent
b39104a449
commit
14c60766c2
@ -214,16 +214,17 @@ register(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# MP environments
|
# MP environments
|
||||||
reacher_envs = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
|
## Simple Reacher
|
||||||
for env in reacher_envs:
|
versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
|
||||||
name = env.split("-")
|
for v in versions:
|
||||||
|
name = v.split("-")
|
||||||
register(
|
register(
|
||||||
id=f'{name[0]}DMP-{name[1]}',
|
id=f'{name[0]}DMP-{name[1]}',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": f"alr_envs:{env}",
|
"name": f"alr_envs:{v}",
|
||||||
"num_dof": 2 if "long" not in env.lower() else 5 ,
|
"num_dof": 2 if "long" not in v.lower() else 5,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
"alpha_phase": 2,
|
"alpha_phase": 2,
|
||||||
@ -249,12 +250,15 @@ register(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## Hole Reacher
|
||||||
|
versions = ["v0", "v1", "v2"]
|
||||||
|
for v in versions:
|
||||||
register(
|
register(
|
||||||
id='HoleReacherDMP-v0',
|
id=f'HoleReacherDMP-{v}',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "alr_envs:HoleReacher-v0",
|
"name": f"alr_envs:HoleReacher-{v}",
|
||||||
"num_dof": 5,
|
"num_dof": 5,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
@ -267,41 +271,12 @@ register(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
# register(
|
||||||
id='HoleReacherDMP-v1',
|
# id='HoleReacherDetPMP-v0',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
# entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp',
|
||||||
# max_episode_steps=1,
|
# # max_episode_steps=1,
|
||||||
kwargs={
|
# # TODO: add mp kwargs
|
||||||
"name": "alr_envs:HoleReacher-v1",
|
# )
|
||||||
"num_dof": 5,
|
|
||||||
"num_basis": 5,
|
|
||||||
"duration": 2,
|
|
||||||
"learn_goal": True,
|
|
||||||
"alpha_phase": 2,
|
|
||||||
"bandwidth_factor": 2,
|
|
||||||
"policy_type": "velocity",
|
|
||||||
"weights_scale": 50,
|
|
||||||
"goal_scale": 0.1
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
register(
|
|
||||||
id='HoleReacherDMP-v2',
|
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
|
||||||
# max_episode_steps=1,
|
|
||||||
kwargs={
|
|
||||||
"name": "alr_envs:HoleReacher-v2",
|
|
||||||
"num_dof": 5,
|
|
||||||
"num_basis": 5,
|
|
||||||
"duration": 2,
|
|
||||||
"learn_goal": True,
|
|
||||||
"alpha_phase": 2,
|
|
||||||
"bandwidth_factor": 2,
|
|
||||||
"policy_type": "velocity",
|
|
||||||
"weights_scale": 50,
|
|
||||||
"goal_scale": 0.1
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: properly add final_pos
|
# TODO: properly add final_pos
|
||||||
register(
|
register(
|
||||||
@ -321,12 +296,7 @@ register(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# register(
|
## Ball in Cup
|
||||||
# id='HoleReacherDetPMP-v0',
|
|
||||||
# entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp',
|
|
||||||
# # max_episode_steps=1,
|
|
||||||
# # TODO: add mp kwargs
|
|
||||||
# )
|
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ALRBallInACupSimpleDMP-v0',
|
id='ALRBallInACupSimpleDMP-v0',
|
||||||
|
@ -127,12 +127,6 @@ class SimpleReacherEnv(MPEnv):
|
|||||||
def _generate_goal(self):
|
def _generate_goal(self):
|
||||||
|
|
||||||
if self._target is None:
|
if self._target is None:
|
||||||
# center = self._joints[0]
|
|
||||||
# # Sample uniformly in circle with radius R around center of reacher.
|
|
||||||
# R = np.sum(self.link_lengths)
|
|
||||||
# r = R * np.sqrt(self.np_random.uniform())
|
|
||||||
# theta = self.np_random.uniform() * 2 * np.pi
|
|
||||||
# goal = center + r * np.stack([np.cos(theta), np.sin(theta)])
|
|
||||||
|
|
||||||
total_length = np.sum(self.link_lengths)
|
total_length = np.sum(self.link_lengths)
|
||||||
goal = np.array([total_length, total_length])
|
goal = np.array([total_length, total_length])
|
||||||
|
@ -99,9 +99,28 @@ class ViaPointReacher(MPEnv):
|
|||||||
return self._get_obs().copy()
|
return self._get_obs().copy()
|
||||||
|
|
||||||
def _generate_goal(self):
|
def _generate_goal(self):
|
||||||
self._via_point = self.np_random.uniform(0.5, 3.5, 2) if self._via_target is None else np.copy(self._via_target)
|
# TODO: Maybe improve this later, this can yield quite a lot of invalid settings
|
||||||
self._goal = self.np_random.uniform(0.5, 0.1, 2) if self._target is None else np.copy(self._target)
|
|
||||||
# raise NotImplementedError("How to properly sample points??")
|
total_length = np.sum(self.link_lengths)
|
||||||
|
|
||||||
|
# rejection sampled point in inner circle with 0.5*Radius
|
||||||
|
if self._via_target is None:
|
||||||
|
via_target = np.array([total_length, total_length])
|
||||||
|
while np.linalg.norm(via_target) >= 0.5 * total_length:
|
||||||
|
via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2)
|
||||||
|
else:
|
||||||
|
via_target = np.copy(self._via_target)
|
||||||
|
|
||||||
|
# rejection sampled point in outer circle
|
||||||
|
if self._target is None:
|
||||||
|
goal = np.array([total_length, total_length])
|
||||||
|
while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length:
|
||||||
|
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
|
||||||
|
else:
|
||||||
|
goal = np.copy(self._target)
|
||||||
|
|
||||||
|
self._via_target = via_target
|
||||||
|
self._goal = goal
|
||||||
|
|
||||||
def _update_joints(self):
|
def _update_joints(self):
|
||||||
"""
|
"""
|
||||||
|
@ -70,7 +70,7 @@ class AlrMpEnvSampler:
|
|||||||
|
|
||||||
class AlrContextualMpEnvSampler:
|
class AlrContextualMpEnvSampler:
|
||||||
"""
|
"""
|
||||||
An asynchronous sampler for non contextual MPWrapper environments. A sampler object can be called with a set of
|
An asynchronous sampler for contextual MPWrapper environments. A sampler object can be called with a set of
|
||||||
parameters and returns the corresponding final obs, rewards, dones and info dicts.
|
parameters and returns the corresponding final obs, rewards, dones and info dicts.
|
||||||
"""
|
"""
|
||||||
def __init__(self, env_id, num_envs, seed=0, **env_kwargs):
|
def __init__(self, env_id, num_envs, seed=0, **env_kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user