2021-05-12 09:52:25 +02:00
|
|
|
from alr_envs.utils.mps.detpmp_wrapper import DetPMPWrapper
|
|
|
|
from alr_envs.utils.mps.dmp_wrapper import DmpWrapper
|
2021-02-11 16:19:57 +01:00
|
|
|
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
|
|
|
|
|
|
|
|
|
2021-02-24 15:37:54 +01:00
|
|
|
def make_contextual_env(rank, seed=0):
    """
    Utility function for multiprocessed env.

    Returns a thunk that, when called, builds an ``ALRBallInACupEnv`` with
    ``reward_type="contextual_goal"`` and wraps it in a ``DetPMPWrapper``
    (``num_dof=7``, ``num_basis=5``, ``policy_type="motor"``,
    ``weights_scale=0.5``). Nothing is constructed until the thunk is invoked.

    :param rank: (int) index of the subprocess; added to ``seed`` so each
        rank receives a different seed value
    :param seed: (int) the initial seed for RNG
    :returns: a zero-argument callable that creates, seeds and returns the
        wrapped environment
    """

    def _init():
        env = ALRBallInACupEnv(reward_type="contextual_goal")

        env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
                            policy_type="motor", weights_scale=0.5, zero_start=True, zero_goal=True)

        # Offset the base seed by the worker index so each rank is seeded differently.
        env.seed(seed + rank)
        return env

    return _init
|
|
|
|
|
|
|
|
|
2021-06-30 15:00:36 +02:00
|
|
|
def _make_env(rank, seed=0):
    """
    Utility function for multiprocessed env.

    Returns a thunk that, when called, builds an ``ALRBallInACupEnv`` with
    ``reward_type="simple"`` and wraps it in a ``DetPMPWrapper``
    (``num_dof=7``, ``num_basis=5``, ``policy_type="motor"``,
    ``weights_scale=0.2``). Nothing is constructed until the thunk is invoked.

    :param rank: (int) index of the subprocess; added to ``seed`` so each
        rank receives a different seed value
    :param seed: (int) the initial seed for RNG
    :returns: a zero-argument callable that creates, seeds and returns the
        wrapped environment
    """

    def _init():
        env = ALRBallInACupEnv(reward_type="simple")

        env = DetPMPWrapper(env, num_dof=7, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
                            policy_type="motor", weights_scale=0.2, zero_start=True, zero_goal=True)

        # Offset the base seed by the worker index so each rank is seeded differently.
        env.seed(seed + rank)
        return env

    return _init
|
|
|
|
|
2021-02-11 16:19:57 +01:00
|
|
|
|
|
|
|
def make_simple_env(rank, seed=0):
    """
    Utility function for multiprocessed env.

    Returns a thunk that, when called, builds an ``ALRBallInACupEnv`` with
    ``reward_type="simple"`` and wraps it in a ``DetPMPWrapper``
    (``num_dof=3``, ``num_basis=5``, ``policy_type="motor"``,
    ``weights_scale=0.25``, ``off=-0.1``). Nothing is constructed until the
    thunk is invoked.

    :param rank: (int) index of the subprocess; added to ``seed`` so each
        rank receives a different seed value
    :param seed: (int) the initial seed for RNG
    :returns: a zero-argument callable that creates, seeds and returns the
        wrapped environment
    """

    def _init():
        # NOTE(review): the env is constructed exactly as in _make_env (which
        # wraps it with num_dof=7), but here it is wrapped with num_dof=3.
        # Confirm that ALRBallInACupEnv(reward_type="simple") accepts 3-DoF
        # actions, or whether a simplified env variant was intended.
        env = ALRBallInACupEnv(reward_type="simple")

        env = DetPMPWrapper(env, num_dof=3, num_basis=5, width=0.005, duration=3.5, dt=env.dt, post_traj_time=4.5,
                            policy_type="motor", weights_scale=0.25, zero_start=True, zero_goal=True, off=-0.1)

        # Offset the base seed by the worker index so each rank is seeded differently.
        env.seed(seed + rank)
        return env

    return _init
|
2021-04-10 19:11:32 +02:00
|
|
|
|
|
|
|
|
|
|
|
def make_simple_dmp_env(rank, seed=0):
    """
    Utility function for multiprocessed env.

    Returns a thunk that, when called, builds an ``ALRBallInACupEnv`` with
    ``reward_type="simple"`` and wraps it in a ``DmpWrapper`` (``num_dof=3``,
    ``num_basis=5``, ``learn_goal=False``, ``policy_type="motor"``,
    ``weights_scale=100``). Nothing is constructed until the thunk is invoked.

    :param rank: (int) index of the subprocess; added to ``seed`` so each
        rank receives a different seed value
    :param seed: (int) the initial seed for RNG
    :returns: a zero-argument callable that creates, seeds and returns the
        wrapped environment
    """

    def _init():
        _env = ALRBallInACupEnv(reward_type="simple")

        # start_pos[1::2] takes every second entry starting at index 1.
        # NOTE(review): this presumably yields exactly 3 values to match
        # num_dof=3 — confirm against the length of ALRBallInACupEnv.start_pos.
        # The same slice is deliberately passed as both start and final position.
        _env = DmpWrapper(_env,
                          num_dof=3,
                          num_basis=5,
                          duration=3.5,
                          post_traj_time=4.5,
                          bandwidth_factor=2.5,
                          dt=_env.dt,
                          learn_goal=False,
                          alpha_phase=3,
                          start_pos=_env.start_pos[1::2],
                          final_pos=_env.start_pos[1::2],
                          policy_type="motor",
                          weights_scale=100,
                          )

        # Offset the base seed by the worker index so each rank is seeded differently.
        _env.seed(seed + rank)
        return _env

    return _init
|