from alr_envs.utils.wrapper.detpmp_wrapper import DetPMPWrapper
from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv


def make_contextual_env(rank, seed=0):
    """
    Build a factory for a DetPMP-wrapped Ball-in-a-Cup env with the
    "contextual_goal" reward, for use with multiprocessed envs.

    Nothing is constructed until the returned function is called.

    :param rank: (int) index of the subprocess; added to ``seed``
    :param seed: (int) the initial seed for RNG
    :returns a zero-argument function that creates, seeds and returns the env
    """

    def _init():
        base_env = ALRBallInACupEnv(reward_type="contextual_goal")
        # start_pos and dt are read from the raw env before it is wrapped.
        wrapper_kwargs = dict(
            num_dof=7,
            num_basis=5,
            width=0.005,
            policy_type="motor",
            start_pos=base_env.start_pos,
            duration=3.5,
            post_traj_time=4.5,
            dt=base_env.dt,
            weights_scale=0.5,
            zero_start=True,
            zero_goal=True,
        )
        wrapped_env = DetPMPWrapper(base_env, **wrapper_kwargs)
        # Each rank gets its own seed offset from the shared base seed.
        wrapped_env.seed(seed + rank)
        return wrapped_env

    return _init


def make_env(rank, seed=0):
    """
    Build a factory for a DetPMP-wrapped Ball-in-a-Cup env with the
    "simple" reward and 7 degrees of freedom, for use with multiprocessed envs.

    Nothing is constructed until the returned function is called.

    :param rank: (int) index of the subprocess; added to ``seed``
    :param seed: (int) the initial seed for RNG
    :returns a zero-argument function that creates, seeds and returns the env
    """

    def _init():
        base_env = ALRBallInACupEnv(reward_type="simple")
        # start_pos and dt are read from the raw env before it is wrapped.
        wrapper_kwargs = dict(
            num_dof=7,
            num_basis=5,
            width=0.005,
            policy_type="motor",
            start_pos=base_env.start_pos,
            duration=3.5,
            post_traj_time=4.5,
            dt=base_env.dt,
            weights_scale=0.2,
            zero_start=True,
            zero_goal=True,
        )
        wrapped_env = DetPMPWrapper(base_env, **wrapper_kwargs)
        # Each rank gets its own seed offset from the shared base seed.
        wrapped_env.seed(seed + rank)
        return wrapped_env

    return _init


def make_simple_env(rank, seed=0):
    """
    Build a factory for a DetPMP-wrapped Ball-in-a-Cup env with the
    "simple" reward and 3 degrees of freedom, for use with multiprocessed envs.

    Nothing is constructed until the returned function is called.

    :param rank: (int) index of the subprocess; added to ``seed``
    :param seed: (int) the initial seed for RNG
    :returns a zero-argument function that creates, seeds and returns the env
    """

    def _init():
        base_env = ALRBallInACupEnv(reward_type="simple")
        wrapper_kwargs = dict(
            num_dof=3,
            num_basis=5,
            width=0.005,
            off=-0.1,
            policy_type="motor",
            # Every second entry of start_pos, beginning at index 1
            # (presumably the 3 controlled joints -- confirm against the env).
            start_pos=base_env.start_pos[1::2],
            duration=3.5,
            post_traj_time=4.5,
            dt=base_env.dt,
            weights_scale=0.25,
            zero_start=True,
            zero_goal=True,
        )
        wrapped_env = DetPMPWrapper(base_env, **wrapper_kwargs)
        # Each rank gets its own seed offset from the shared base seed.
        wrapped_env.seed(seed + rank)
        return wrapped_env

    return _init


def make_simple_dmp_env(rank, seed=0):
    """
    Build a factory for a DMP-wrapped Ball-in-a-Cup env with the
    "simple" reward and 3 degrees of freedom, for use with multiprocessed envs.

    Nothing is constructed until the returned function is called.

    :param rank: (int) index of the subprocess; added to ``seed``
    :param seed: (int) the initial seed for RNG
    :returns a zero-argument function that creates, seeds and returns the env
    """

    def _init():
        base_env = ALRBallInACupEnv(reward_type="simple")
        wrapper_kwargs = dict(
            num_dof=3,
            num_basis=5,
            duration=3.5,
            post_traj_time=4.5,
            bandwidth_factor=2.5,
            dt=base_env.dt,
            learn_goal=False,
            alpha_phase=3,
            # Start and final positions use the same slice of start_pos
            # (every second entry from index 1); each is sliced separately
            # so the two arguments are independent objects.
            start_pos=base_env.start_pos[1::2],
            final_pos=base_env.start_pos[1::2],
            policy_type="motor",
            weights_scale=100,
        )
        wrapped_env = DmpWrapper(base_env, **wrapper_kwargs)
        # Each rank gets its own seed offset from the shared base seed.
        wrapped_env.seed(seed + rank)
        return wrapped_env

    return _init