Reacher envs pass unit tests
This commit is contained in:
parent
cfa49a04ba
commit
ca90f257d4
@ -237,7 +237,7 @@ for _v in _versions:
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2 if "long" not in _v.lower() else 5,
|
||||
"num_basis": 5,
|
||||
"duration": 20,
|
||||
"duration": 2,
|
||||
"alpha_phase": 2,
|
||||
"learn_goal": True,
|
||||
"policy_type": "velocity",
|
||||
@ -277,7 +277,7 @@ for _v in _versions:
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2 if "long" not in _v.lower() else 5,
|
||||
"num_basis": 5,
|
||||
"duration": 20,
|
||||
"duration": 2,
|
||||
"width": 0.025,
|
||||
"policy_type": "velocity",
|
||||
"weights_scale": 0.2,
|
||||
|
@ -22,6 +22,8 @@ class BaseReacherEnv(gym.Env, ABC):
|
||||
|
||||
self.random_start = random_start
|
||||
|
||||
self.allow_self_collision = allow_self_collision
|
||||
|
||||
# state
|
||||
self._joints = None
|
||||
self._joint_angles = None
|
||||
@ -103,6 +105,9 @@ class BaseReacherEnv(gym.Env, ABC):
|
||||
def _check_self_collision(self):
|
||||
"""Checks whether line segments intersect"""
|
||||
|
||||
if self.allow_self_collision:
|
||||
return False
|
||||
|
||||
if np.any(self._joint_angles > self.j_max) or np.any(self._joint_angles < self.j_min):
|
||||
return True
|
||||
|
||||
|
@ -127,6 +127,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
||||
|
||||
return np.squeeze(end_effector + self._joints[0, :])
|
||||
|
||||
def _check_collisions(self) -> bool:
|
||||
return self._check_self_collision() or self.check_wall_collision()
|
||||
|
||||
def check_wall_collision(self):
|
||||
line_points = self._get_line_points(num_points_per_link=100)
|
||||
|
||||
@ -223,6 +226,6 @@ if __name__ == "__main__":
|
||||
for i in range(10000):
|
||||
ac = env.action_space.sample()
|
||||
obs, rew, done, info = env.step(ac)
|
||||
# env.render()
|
||||
env.render()
|
||||
if done:
|
||||
env.reset()
|
||||
|
@ -88,6 +88,9 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
|
||||
|
||||
self._goal = goal
|
||||
|
||||
def _check_collisions(self) -> bool:
|
||||
return self._check_self_collision()
|
||||
|
||||
def render(self, mode='human'): # pragma: no cover
|
||||
if self.fig is None:
|
||||
# Create base figure once at the beginning. Afterwards only update
|
||||
|
@ -25,7 +25,6 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
||||
self._goal = np.array((n_links, 0))
|
||||
|
||||
# collision
|
||||
self.allow_self_collision = allow_self_collision
|
||||
self.collision_penalty = collision_penalty
|
||||
|
||||
state_bound = np.hstack([
|
||||
@ -96,7 +95,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
||||
reward -= 5e-8 * np.sum(acc ** 2)
|
||||
info = {"is_success": success,
|
||||
"is_collided": self._is_collided,
|
||||
"end_effector": np.copy(env.end_effector)}
|
||||
"end_effector": np.copy(self.end_effector)}
|
||||
|
||||
return reward, info
|
||||
|
||||
@ -114,6 +113,9 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
||||
self._steps
|
||||
])
|
||||
|
||||
def _check_collisions(self) -> bool:
|
||||
return self._check_self_collision()
|
||||
|
||||
def render(self, mode='human'):
|
||||
goal_pos = self._goal.T
|
||||
via_pos = self._via_point.T
|
||||
|
@ -87,6 +87,7 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
|
||||
[self._steps],
|
||||
])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
nl = 5
|
||||
render_mode = "human" # "human" or "partial" or "final"
|
||||
|
@ -112,7 +112,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
||||
|
||||
# Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
|
||||
# You can also add other gym.Wrappers in case they are needed.
|
||||
wrappers = [alr_envs.classic_control.hole_reacher.MPWrapper]
|
||||
wrappers = [alr_envs.alr.classic_control.hole_reacher.MPWrapper]
|
||||
mp_kwargs = {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
@ -147,15 +147,15 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
render = False
|
||||
render = True
|
||||
# DMP
|
||||
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=10, render=render)
|
||||
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
||||
|
||||
# ProMP
|
||||
example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=100, render=render)
|
||||
example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=1, render=render)
|
||||
|
||||
# DetProMP
|
||||
example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=100, render=render)
|
||||
example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)
|
||||
|
||||
# Altered basis functions
|
||||
example_custom_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
||||
|
Loading…
Reference in New Issue
Block a user