diff --git a/alr_envs/mujoco/ball_in_a_cup/utils.py b/alr_envs/mujoco/ball_in_a_cup/utils.py index f889736..2e90404 100644 --- a/alr_envs/mujoco/ball_in_a_cup/utils.py +++ b/alr_envs/mujoco/ball_in_a_cup/utils.py @@ -88,13 +88,14 @@ def make_simple_env(rank, seed=0): env = DetPMPEnvWrapper(env, num_dof=3, num_basis=5, - width=0.01, + width=0.005, + off=-0.1, policy_type="motor", start_pos=env.start_pos[1::2], duration=3.5, post_traj_time=4.5, dt=env.dt, - weights_scale=0.5, + weights_scale=0.25, zero_start=True, zero_goal=True ) diff --git a/alr_envs/utils/detpmp_env_wrapper.py b/alr_envs/utils/detpmp_env_wrapper.py index f49862e..c667abf 100644 --- a/alr_envs/utils/detpmp_env_wrapper.py +++ b/alr_envs/utils/detpmp_env_wrapper.py @@ -10,6 +10,7 @@ class DetPMPEnvWrapper(gym.Wrapper): num_dof, num_basis, width, + off=0.01, start_pos=None, duration=1, dt=0.01, @@ -23,7 +24,7 @@ class DetPMPEnvWrapper(gym.Wrapper): self.num_dof = num_dof self.num_basis = num_basis self.dim = num_dof * num_basis - self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=0.01, + self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=off, zero_start=zero_start, zero_goal=zero_goal) weights = np.zeros(shape=(num_basis, num_dof)) self.pmp.set_weights(duration, weights)