updated via point reacher example to new structure
This commit is contained in:
parent
fa7dfdc081
commit
f3d837349a
@ -7,9 +7,10 @@ from gym.utils import seeding
|
|||||||
|
|
||||||
from alr_envs.classic_control.utils import check_self_collision
|
from alr_envs.classic_control.utils import check_self_collision
|
||||||
from mp_env_api.envs.mp_env import MpEnv
|
from mp_env_api.envs.mp_env import MpEnv
|
||||||
|
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
class ViaPointReacher(MpEnv):
|
class ViaPointReacher(gym.Env):
|
||||||
|
|
||||||
def __init__(self, n_links, random_start: bool = True, via_target: Union[None, Iterable] = None,
|
def __init__(self, n_links, random_start: bool = True, via_target: Union[None, Iterable] = None,
|
||||||
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
|
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
|
||||||
@ -20,8 +21,8 @@ class ViaPointReacher(MpEnv):
|
|||||||
self.random_start = random_start
|
self.random_start = random_start
|
||||||
|
|
||||||
# provided initial parameters
|
# provided initial parameters
|
||||||
self._target = target # provided target value
|
self.target = target # provided target value
|
||||||
self._via_target = via_target # provided via point target value
|
self.via_target = via_target # provided via point target value
|
||||||
|
|
||||||
# temp container for current env state
|
# temp container for current env state
|
||||||
self._via_point = np.ones(2)
|
self._via_point = np.ones(2)
|
||||||
@ -39,7 +40,7 @@ class ViaPointReacher(MpEnv):
|
|||||||
self._start_vel = np.zeros(self.n_links)
|
self._start_vel = np.zeros(self.n_links)
|
||||||
self.weight_matrix_scale = 1
|
self.weight_matrix_scale = 1
|
||||||
|
|
||||||
self.dt = 0.01
|
self._dt = 0.01
|
||||||
|
|
||||||
action_bound = np.pi * np.ones((self.n_links,))
|
action_bound = np.pi * np.ones((self.n_links,))
|
||||||
state_bound = np.hstack([
|
state_bound = np.hstack([
|
||||||
@ -60,6 +61,10 @@ class ViaPointReacher(MpEnv):
|
|||||||
self._steps = 0
|
self._steps = 0
|
||||||
self.seed()
|
self.seed()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dt(self):
|
||||||
|
return self._dt
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
"""
|
"""
|
||||||
a single step with an action in joint velocity space
|
a single step with an action in joint velocity space
|
||||||
@ -104,22 +109,22 @@ class ViaPointReacher(MpEnv):
|
|||||||
total_length = np.sum(self.link_lengths)
|
total_length = np.sum(self.link_lengths)
|
||||||
|
|
||||||
# rejection sampled point in inner circle with 0.5*Radius
|
# rejection sampled point in inner circle with 0.5*Radius
|
||||||
if self._via_target is None:
|
if self.via_target is None:
|
||||||
via_target = np.array([total_length, total_length])
|
via_target = np.array([total_length, total_length])
|
||||||
while np.linalg.norm(via_target) >= 0.5 * total_length:
|
while np.linalg.norm(via_target) >= 0.5 * total_length:
|
||||||
via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2)
|
via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2)
|
||||||
else:
|
else:
|
||||||
via_target = np.copy(self._via_target)
|
via_target = np.copy(self.via_target)
|
||||||
|
|
||||||
# rejection sampled point in outer circle
|
# rejection sampled point in outer circle
|
||||||
if self._target is None:
|
if self.target is None:
|
||||||
goal = np.array([total_length, total_length])
|
goal = np.array([total_length, total_length])
|
||||||
while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length:
|
while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length:
|
||||||
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
|
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
|
||||||
else:
|
else:
|
||||||
goal = np.copy(self._target)
|
goal = np.copy(self.target)
|
||||||
|
|
||||||
self._via_target = via_target
|
self.via_target = via_target
|
||||||
self._goal = goal
|
self._goal = goal
|
||||||
|
|
||||||
def _update_joints(self):
|
def _update_joints(self):
|
||||||
@ -266,25 +271,6 @@ class ViaPointReacher(MpEnv):
|
|||||||
|
|
||||||
plt.pause(0.01)
|
plt.pause(0.01)
|
||||||
|
|
||||||
@property
|
|
||||||
def active_obs(self):
|
|
||||||
return np.hstack([
|
|
||||||
[self.random_start] * self.n_links, # cos
|
|
||||||
[self.random_start] * self.n_links, # sin
|
|
||||||
[self.random_start] * self.n_links, # velocity
|
|
||||||
[self._via_target is None] * 2, # x-y coordinates of via point distance
|
|
||||||
[True] * 2, # x-y coordinates of target distance
|
|
||||||
[False] # env steps
|
|
||||||
])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def start_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return self._start_pos
|
|
||||||
|
|
||||||
@property
|
|
||||||
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
self.np_random, seed = seeding.np_random(seed)
|
self.np_random, seed = seeding.np_random(seed)
|
||||||
return [seed]
|
return [seed]
|
||||||
@ -298,24 +284,25 @@ class ViaPointReacher(MpEnv):
|
|||||||
plt.close(self.fig)
|
plt.close(self.fig)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
class ViaPointReacherMPWrapper(MPEnvWrapper):
|
||||||
nl = 5
|
@property
|
||||||
render_mode = "human" # "human" or "partial" or "final"
|
def active_obs(self):
|
||||||
env = ViaPointReacher(n_links=nl, allow_self_collision=False)
|
return np.hstack([
|
||||||
env.reset()
|
[self.env.random_start] * self.env.n_links, # cos
|
||||||
env.render(mode=render_mode)
|
[self.env.random_start] * self.env.n_links, # sin
|
||||||
|
[self.env.random_start] * self.env.n_links, # velocity
|
||||||
|
[self.env.via_target is None] * 2, # x-y coordinates of via point distance
|
||||||
|
[True] * 2, # x-y coordinates of target distance
|
||||||
|
[False] # env steps
|
||||||
|
])
|
||||||
|
|
||||||
for i in range(300):
|
@property
|
||||||
# objective.load_result("/tmp/cma")
|
def start_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
# test with random actions
|
return self._start_pos
|
||||||
ac = env.action_space.sample()
|
|
||||||
# ac[0] += np.pi/2
|
|
||||||
obs, rew, d, info = env.step(ac)
|
|
||||||
env.render(mode=render_mode)
|
|
||||||
|
|
||||||
print(rew)
|
@property
|
||||||
|
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||||
|
|
||||||
if d:
|
def dt(self) -> Union[float, int]:
|
||||||
break
|
return self.env.dt
|
||||||
|
|
||||||
env.close()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user