diff --git a/alr_envs/classic_control/utils.py b/alr_envs/classic_control/utils.py index f2ead72..3540ad2 100644 --- a/alr_envs/classic_control/utils.py +++ b/alr_envs/classic_control/utils.py @@ -1,6 +1,7 @@ from alr_envs.classic_control.hole_reacher import HoleReacher from alr_envs.classic_control.viapoint_reacher import ViaPointReacher from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapper +from alr_envs.utils.detpmp_env_wrapper import DetPMPEnvWrapper def make_viapointreacher_env(rank, seed=0): @@ -53,7 +54,7 @@ def make_holereacher_env(rank, seed=0): hole_width=0.15, hole_depth=1, hole_x=1, - collision_penalty=100000) + collision_penalty=1000) _env = DmpEnvWrapper(_env, num_dof=5, @@ -66,6 +67,46 @@ def make_holereacher_env(rank, seed=0): policy_type="velocity", weights_scale=100, ) + + _env.seed(seed + rank) + return _env + + return _init + + +def make_holereacher_env_pmp(rank, seed=0): + """ + Utility function for multiprocessed env. + + :param env_id: (str) the environment ID + :param num_env: (int) the number of environments you wish to have in subprocesses + :param seed: (int) the initial seed for RNG + :param rank: (int) index of the subprocess + :returns a function that generates an environment + """ + + def _init(): + _env = HoleReacher(num_links=5, + allow_self_collision=False, + allow_wall_collision=False, + hole_width=0.15, + hole_depth=1, + hole_x=1, + collision_penalty=1000) + + _env = DetPMPEnvWrapper(_env, + num_dof=5, + num_basis=5, + width=0.005, + policy_type="velocity", + start_pos=_env.start_pos, + duration=2, + post_traj_time=0, + dt=_env.dt, + weights_scale=0.15, + zero_start=True, + zero_goal=False + ) _env.seed(seed + rank) return _env diff --git a/alr_envs/mujoco/ball_in_a_cup/utils.py b/alr_envs/mujoco/ball_in_a_cup/utils.py index d65b1fe..876875c 100644 --- a/alr_envs/mujoco/ball_in_a_cup/utils.py +++ b/alr_envs/mujoco/ball_in_a_cup/utils.py @@ -60,7 +60,7 @@ def make_env(rank, seed=0): duration=3.5, post_traj_time=4.5, dt=env.dt, - weights_scale=0.15, + weights_scale=0.25, zero_start=True, zero_goal=True ) diff --git a/alr_envs/utils/detpmp_env_wrapper.py b/alr_envs/utils/detpmp_env_wrapper.py index 3e72c06..f49862e 100644 --- a/alr_envs/utils/detpmp_env_wrapper.py +++ b/alr_envs/utils/detpmp_env_wrapper.py @@ -34,7 +34,7 @@ class DetPMPEnvWrapper(gym.Wrapper): self.post_traj_steps = int(post_traj_time / dt) self.start_pos = start_pos - self.zero_centered = zero_start + self.zero_start = zero_start policy_class = get_policy_class(policy_type) self.policy = policy_class(env) @@ -55,7 +55,7 @@ class DetPMPEnvWrapper(gym.Wrapper): params = np.reshape(params, newshape=(self.num_basis, self.num_dof)) * self.weights_scale self.pmp.set_weights(self.duration, params) t, des_pos, des_vel, des_acc = self.pmp.compute_trajectory(1 / self.dt, 1.) - if self.zero_centered: + if self.zero_start: des_pos += self.start_pos[None, :] if self.post_traj_steps > 0: