This commit is contained in:
Maximilian Huettenrauch 2021-04-08 19:00:48 +02:00
parent f7c2f800ed
commit 1fd44616bf
2 changed files with 5 additions and 3 deletions

View File

@ -88,13 +88,14 @@ def make_simple_env(rank, seed=0):
env = DetPMPEnvWrapper(env, env = DetPMPEnvWrapper(env,
num_dof=3, num_dof=3,
num_basis=5, num_basis=5,
width=0.01, width=0.005,
off=-0.1,
policy_type="motor", policy_type="motor",
start_pos=env.start_pos[1::2], start_pos=env.start_pos[1::2],
duration=3.5, duration=3.5,
post_traj_time=4.5, post_traj_time=4.5,
dt=env.dt, dt=env.dt,
weights_scale=0.5, weights_scale=0.25,
zero_start=True, zero_start=True,
zero_goal=True zero_goal=True
) )

View File

@ -10,6 +10,7 @@ class DetPMPEnvWrapper(gym.Wrapper):
num_dof, num_dof,
num_basis, num_basis,
width, width,
off=0.01,
start_pos=None, start_pos=None,
duration=1, duration=1,
dt=0.01, dt=0.01,
@ -23,7 +24,7 @@ class DetPMPEnvWrapper(gym.Wrapper):
self.num_dof = num_dof self.num_dof = num_dof
self.num_basis = num_basis self.num_basis = num_basis
self.dim = num_dof * num_basis self.dim = num_dof * num_basis
self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=0.01, self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=off,
zero_start=zero_start, zero_goal=zero_goal) zero_start=zero_start, zero_goal=zero_goal)
weights = np.zeros(shape=(num_basis, num_dof)) weights = np.zeros(shape=(num_basis, num_dof))
self.pmp.set_weights(duration, weights) self.pmp.set_weights(duration, weights)