reacher envs pass unittests
This commit is contained in:
parent
cfa49a04ba
commit
ca90f257d4
@ -237,7 +237,7 @@ for _v in _versions:
|
|||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 2 if "long" not in _v.lower() else 5,
|
"num_dof": 2 if "long" not in _v.lower() else 5,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 20,
|
"duration": 2,
|
||||||
"alpha_phase": 2,
|
"alpha_phase": 2,
|
||||||
"learn_goal": True,
|
"learn_goal": True,
|
||||||
"policy_type": "velocity",
|
"policy_type": "velocity",
|
||||||
@ -277,7 +277,7 @@ for _v in _versions:
|
|||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 2 if "long" not in _v.lower() else 5,
|
"num_dof": 2 if "long" not in _v.lower() else 5,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 20,
|
"duration": 2,
|
||||||
"width": 0.025,
|
"width": 0.025,
|
||||||
"policy_type": "velocity",
|
"policy_type": "velocity",
|
||||||
"weights_scale": 0.2,
|
"weights_scale": 0.2,
|
||||||
|
@ -22,6 +22,8 @@ class BaseReacherEnv(gym.Env, ABC):
|
|||||||
|
|
||||||
self.random_start = random_start
|
self.random_start = random_start
|
||||||
|
|
||||||
|
self.allow_self_collision = allow_self_collision
|
||||||
|
|
||||||
# state
|
# state
|
||||||
self._joints = None
|
self._joints = None
|
||||||
self._joint_angles = None
|
self._joint_angles = None
|
||||||
@ -103,6 +105,9 @@ class BaseReacherEnv(gym.Env, ABC):
|
|||||||
def _check_self_collision(self):
|
def _check_self_collision(self):
|
||||||
"""Checks whether line segments intersect"""
|
"""Checks whether line segments intersect"""
|
||||||
|
|
||||||
|
if self.allow_self_collision:
|
||||||
|
return False
|
||||||
|
|
||||||
if np.any(self._joint_angles > self.j_max) or np.any(self._joint_angles < self.j_min):
|
if np.any(self._joint_angles > self.j_max) or np.any(self._joint_angles < self.j_min):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -127,6 +127,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
|||||||
|
|
||||||
return np.squeeze(end_effector + self._joints[0, :])
|
return np.squeeze(end_effector + self._joints[0, :])
|
||||||
|
|
||||||
|
def _check_collisions(self) -> bool:
|
||||||
|
return self._check_self_collision() or self.check_wall_collision()
|
||||||
|
|
||||||
def check_wall_collision(self):
|
def check_wall_collision(self):
|
||||||
line_points = self._get_line_points(num_points_per_link=100)
|
line_points = self._get_line_points(num_points_per_link=100)
|
||||||
|
|
||||||
@ -223,6 +226,6 @@ if __name__ == "__main__":
|
|||||||
for i in range(10000):
|
for i in range(10000):
|
||||||
ac = env.action_space.sample()
|
ac = env.action_space.sample()
|
||||||
obs, rew, done, info = env.step(ac)
|
obs, rew, done, info = env.step(ac)
|
||||||
# env.render()
|
env.render()
|
||||||
if done:
|
if done:
|
||||||
env.reset()
|
env.reset()
|
||||||
|
@ -88,6 +88,9 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
|
|||||||
|
|
||||||
self._goal = goal
|
self._goal = goal
|
||||||
|
|
||||||
|
def _check_collisions(self) -> bool:
|
||||||
|
return self._check_self_collision()
|
||||||
|
|
||||||
def render(self, mode='human'): # pragma: no cover
|
def render(self, mode='human'): # pragma: no cover
|
||||||
if self.fig is None:
|
if self.fig is None:
|
||||||
# Create base figure once on the beginning. Afterwards only update
|
# Create base figure once on the beginning. Afterwards only update
|
||||||
|
@ -25,7 +25,6 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
|||||||
self._goal = np.array((n_links, 0))
|
self._goal = np.array((n_links, 0))
|
||||||
|
|
||||||
# collision
|
# collision
|
||||||
self.allow_self_collision = allow_self_collision
|
|
||||||
self.collision_penalty = collision_penalty
|
self.collision_penalty = collision_penalty
|
||||||
|
|
||||||
state_bound = np.hstack([
|
state_bound = np.hstack([
|
||||||
@ -96,7 +95,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
|||||||
reward -= 5e-8 * np.sum(acc ** 2)
|
reward -= 5e-8 * np.sum(acc ** 2)
|
||||||
info = {"is_success": success,
|
info = {"is_success": success,
|
||||||
"is_collided": self._is_collided,
|
"is_collided": self._is_collided,
|
||||||
"end_effector": np.copy(env.end_effector)}
|
"end_effector": np.copy(self.end_effector)}
|
||||||
|
|
||||||
return reward, info
|
return reward, info
|
||||||
|
|
||||||
@ -114,6 +113,9 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
|||||||
self._steps
|
self._steps
|
||||||
])
|
])
|
||||||
|
|
||||||
|
def _check_collisions(self) -> bool:
|
||||||
|
return self._check_self_collision()
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
goal_pos = self._goal.T
|
goal_pos = self._goal.T
|
||||||
via_pos = self._via_point.T
|
via_pos = self._via_point.T
|
||||||
|
@ -87,6 +87,7 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
|
|||||||
[self._steps],
|
[self._steps],
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
nl = 5
|
nl = 5
|
||||||
render_mode = "human" # "human" or "partial" or "final"
|
render_mode = "human" # "human" or "partial" or "final"
|
||||||
|
@ -112,7 +112,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
|||||||
|
|
||||||
# Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
|
# Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
|
||||||
# You can also add other gym.Wrappers in case they are needed.
|
# You can also add other gym.Wrappers in case they are needed.
|
||||||
wrappers = [alr_envs.classic_control.hole_reacher.MPWrapper]
|
wrappers = [alr_envs.alr.classic_control.hole_reacher.MPWrapper]
|
||||||
mp_kwargs = {
|
mp_kwargs = {
|
||||||
"num_dof": 5,
|
"num_dof": 5,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
@ -147,15 +147,15 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
render = False
|
render = True
|
||||||
# DMP
|
# DMP
|
||||||
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=10, render=render)
|
example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
# ProMP
|
# ProMP
|
||||||
example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=100, render=render)
|
example_mp("alr_envs:HoleReacherProMP-v1", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
# DetProMP
|
# DetProMP
|
||||||
example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=100, render=render)
|
example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
# Altered basis functions
|
# Altered basis functions
|
||||||
example_custom_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
example_custom_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
|
||||||
|
Loading…
Reference in New Issue
Block a user