fixed open issues
This commit is contained in:
parent
b39104a449
commit
14c60766c2
@ -214,16 +214,17 @@ register(
|
||||
)
|
||||
|
||||
# MP environments
|
||||
reacher_envs = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
|
||||
for env in reacher_envs:
|
||||
name = env.split("-")
|
||||
## Simple Reacher
|
||||
versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
|
||||
for v in versions:
|
||||
name = v.split("-")
|
||||
register(
|
||||
id=f'{name[0]}DMP-{name[1]}',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"alr_envs:{env}",
|
||||
"num_dof": 2 if "long" not in env.lower() else 5 ,
|
||||
"name": f"alr_envs:{v}",
|
||||
"num_dof": 2 if "long" not in v.lower() else 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
"alpha_phase": 2,
|
||||
@ -249,12 +250,15 @@ register(
|
||||
}
|
||||
)
|
||||
|
||||
## Hole Reacher
|
||||
versions = ["v0", "v1", "v2"]
|
||||
for v in versions:
|
||||
register(
|
||||
id='HoleReacherDMP-v0',
|
||||
id=f'HoleReacherDMP-{v}',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": "alr_envs:HoleReacher-v0",
|
||||
"name": f"alr_envs:HoleReacher-{v}",
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
@ -267,41 +271,12 @@ register(
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='HoleReacherDMP-v1',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": "alr_envs:HoleReacher-v1",
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "velocity",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='HoleReacherDMP-v2',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": "alr_envs:HoleReacher-v2",
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
"policy_type": "velocity",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1
|
||||
}
|
||||
)
|
||||
# register(
|
||||
# id='HoleReacherDetPMP-v0',
|
||||
# entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp',
|
||||
# # max_episode_steps=1,
|
||||
# # TODO: add mp kwargs
|
||||
# )
|
||||
|
||||
# TODO: properly add final_pos
|
||||
register(
|
||||
@ -321,12 +296,7 @@ register(
|
||||
}
|
||||
)
|
||||
|
||||
# register(
|
||||
# id='HoleReacherDetPMP-v0',
|
||||
# entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp',
|
||||
# # max_episode_steps=1,
|
||||
# # TODO: add mp kwargs
|
||||
# )
|
||||
## Ball in Cup
|
||||
|
||||
register(
|
||||
id='ALRBallInACupSimpleDMP-v0',
|
||||
|
@ -127,12 +127,6 @@ class SimpleReacherEnv(MPEnv):
|
||||
def _generate_goal(self):
|
||||
|
||||
if self._target is None:
|
||||
# center = self._joints[0]
|
||||
# # Sample uniformly in circle with radius R around center of reacher.
|
||||
# R = np.sum(self.link_lengths)
|
||||
# r = R * np.sqrt(self.np_random.uniform())
|
||||
# theta = self.np_random.uniform() * 2 * np.pi
|
||||
# goal = center + r * np.stack([np.cos(theta), np.sin(theta)])
|
||||
|
||||
total_length = np.sum(self.link_lengths)
|
||||
goal = np.array([total_length, total_length])
|
||||
|
@ -99,9 +99,28 @@ class ViaPointReacher(MPEnv):
|
||||
return self._get_obs().copy()
|
||||
|
||||
def _generate_goal(self):
|
||||
self._via_point = self.np_random.uniform(0.5, 3.5, 2) if self._via_target is None else np.copy(self._via_target)
|
||||
self._goal = self.np_random.uniform(0.5, 0.1, 2) if self._target is None else np.copy(self._target)
|
||||
# raise NotImplementedError("How to properly sample points??")
|
||||
# TODO: Maybe improve this later, this can yield quite a lot of invalid settings
|
||||
|
||||
total_length = np.sum(self.link_lengths)
|
||||
|
||||
# rejection sampled point in inner circle with 0.5*Radius
|
||||
if self._via_target is None:
|
||||
via_target = np.array([total_length, total_length])
|
||||
while np.linalg.norm(via_target) >= 0.5 * total_length:
|
||||
via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2)
|
||||
else:
|
||||
via_target = np.copy(self._via_target)
|
||||
|
||||
# rejection sampled point in outer circle
|
||||
if self._target is None:
|
||||
goal = np.array([total_length, total_length])
|
||||
while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length:
|
||||
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
|
||||
else:
|
||||
goal = np.copy(self._target)
|
||||
|
||||
self._via_target = via_target
|
||||
self._goal = goal
|
||||
|
||||
def _update_joints(self):
|
||||
"""
|
||||
|
@ -70,7 +70,7 @@ class AlrMpEnvSampler:
|
||||
|
||||
class AlrContextualMpEnvSampler:
|
||||
"""
|
||||
An asynchronous sampler for non contextual MPWrapper environments. A sampler object can be called with a set of
|
||||
An asynchronous sampler for contextual MPWrapper environments. A sampler object can be called with a set of
|
||||
parameters and returns the corresponding final obs, rewards, dones and info dicts.
|
||||
"""
|
||||
def __init__(self, env_id, num_envs, seed=0, **env_kwargs):
|
||||
|
Loading…
Reference in New Issue
Block a user