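updates: add make_holereacher_env_pmp (HoleReacher wrapped in DetPMPEnvWrapper), lower hole-reacher collision_penalty from 100000 to 1000, raise make_env weights_scale from 0.15 to 0.25, rename DetPMPEnvWrapper.zero_centered to zero_start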
This commit is contained in:
parent
27d7da6774
commit
1d8b22245d
@ -1,6 +1,7 @@
|
|||||||
from alr_envs.classic_control.hole_reacher import HoleReacher
|
from alr_envs.classic_control.hole_reacher import HoleReacher
|
||||||
from alr_envs.classic_control.viapoint_reacher import ViaPointReacher
|
from alr_envs.classic_control.viapoint_reacher import ViaPointReacher
|
||||||
from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapper
|
from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapper
|
||||||
|
from alr_envs.utils.detpmp_env_wrapper import DetPMPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
def make_viapointreacher_env(rank, seed=0):
|
def make_viapointreacher_env(rank, seed=0):
|
||||||
@ -53,7 +54,7 @@ def make_holereacher_env(rank, seed=0):
|
|||||||
hole_width=0.15,
|
hole_width=0.15,
|
||||||
hole_depth=1,
|
hole_depth=1,
|
||||||
hole_x=1,
|
hole_x=1,
|
||||||
collision_penalty=100000)
|
collision_penalty=1000)
|
||||||
|
|
||||||
_env = DmpEnvWrapper(_env,
|
_env = DmpEnvWrapper(_env,
|
||||||
num_dof=5,
|
num_dof=5,
|
||||||
@ -66,6 +67,46 @@ def make_holereacher_env(rank, seed=0):
|
|||||||
policy_type="velocity",
|
policy_type="velocity",
|
||||||
weights_scale=100,
|
weights_scale=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_env.seed(seed + rank)
|
||||||
|
return _env
|
||||||
|
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
def make_holereacher_env_pmp(rank, seed=0):
|
||||||
|
"""
|
||||||
|
Utility function for multiprocessed env.
|
||||||
|
|
||||||
|
:param env_id: (str) the environment ID
|
||||||
|
:param num_env: (int) the number of environments you wish to have in subprocesses
|
||||||
|
:param seed: (int) the initial seed for RNG
|
||||||
|
:param rank: (int) index of the subprocess
|
||||||
|
:returns a function that generates an environment
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _init():
|
||||||
|
_env = HoleReacher(num_links=5,
|
||||||
|
allow_self_collision=False,
|
||||||
|
allow_wall_collision=False,
|
||||||
|
hole_width=0.15,
|
||||||
|
hole_depth=1,
|
||||||
|
hole_x=1,
|
||||||
|
collision_penalty=1000)
|
||||||
|
|
||||||
|
_env = DetPMPEnvWrapper(_env,
|
||||||
|
num_dof=5,
|
||||||
|
num_basis=5,
|
||||||
|
width=0.005,
|
||||||
|
policy_type="velocity",
|
||||||
|
start_pos=_env.start_pos,
|
||||||
|
duration=2,
|
||||||
|
post_traj_time=0,
|
||||||
|
dt=_env.dt,
|
||||||
|
weights_scale=0.15,
|
||||||
|
zero_start=True,
|
||||||
|
zero_goal=False
|
||||||
|
)
|
||||||
_env.seed(seed + rank)
|
_env.seed(seed + rank)
|
||||||
return _env
|
return _env
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ def make_env(rank, seed=0):
|
|||||||
duration=3.5,
|
duration=3.5,
|
||||||
post_traj_time=4.5,
|
post_traj_time=4.5,
|
||||||
dt=env.dt,
|
dt=env.dt,
|
||||||
weights_scale=0.15,
|
weights_scale=0.25,
|
||||||
zero_start=True,
|
zero_start=True,
|
||||||
zero_goal=True
|
zero_goal=True
|
||||||
)
|
)
|
||||||
|
@ -34,7 +34,7 @@ class DetPMPEnvWrapper(gym.Wrapper):
|
|||||||
self.post_traj_steps = int(post_traj_time / dt)
|
self.post_traj_steps = int(post_traj_time / dt)
|
||||||
|
|
||||||
self.start_pos = start_pos
|
self.start_pos = start_pos
|
||||||
self.zero_centered = zero_start
|
self.zero_start = zero_start
|
||||||
|
|
||||||
policy_class = get_policy_class(policy_type)
|
policy_class = get_policy_class(policy_type)
|
||||||
self.policy = policy_class(env)
|
self.policy = policy_class(env)
|
||||||
@ -55,7 +55,7 @@ class DetPMPEnvWrapper(gym.Wrapper):
|
|||||||
params = np.reshape(params, newshape=(self.num_basis, self.num_dof)) * self.weights_scale
|
params = np.reshape(params, newshape=(self.num_basis, self.num_dof)) * self.weights_scale
|
||||||
self.pmp.set_weights(self.duration, params)
|
self.pmp.set_weights(self.duration, params)
|
||||||
t, des_pos, des_vel, des_acc = self.pmp.compute_trajectory(1 / self.dt, 1.)
|
t, des_pos, des_vel, des_acc = self.pmp.compute_trajectory(1 / self.dt, 1.)
|
||||||
if self.zero_centered:
|
if self.zero_start:
|
||||||
des_pos += self.start_pos[None, :]
|
des_pos += self.start_pos[None, :]
|
||||||
|
|
||||||
if self.post_traj_steps > 0:
|
if self.post_traj_steps > 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user