fancy_gym/alr_envs/classic_control/utils.py

156 lines
5.2 KiB
Python
Raw Normal View History

2021-01-14 17:10:03 +01:00
from alr_envs.classic_control.hole_reacher import HoleReacher
2021-02-15 09:03:19 +01:00
from alr_envs.classic_control.viapoint_reacher import ViaPointReacher
from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
from alr_envs.utils.wrapper.detpmp_wrapper import DetPMPWrapper
2021-03-19 16:31:46 +01:00
import numpy as np
2021-01-14 17:10:03 +01:00
2021-02-15 09:03:19 +01:00
def make_viapointreacher_env(rank, seed=0):
    """
    Utility function for multiprocessed env.

    Returns a factory for a 5-link ``ViaPointReacher`` wrapped in a
    ``DmpWrapper``. Nothing is constructed until the returned function is
    called.

    :param rank: (int) index of the subprocess
    :param seed: (int) the initial seed for RNG
    :returns a function that generates an environment
    """

    def _init():
        _env = ViaPointReacher(n_links=5,
                               allow_self_collision=False,
                               collision_penalty=1000)

        # dt and start_pos are read from the raw environment before the
        # name is rebound to the wrapper.
        _env = DmpWrapper(_env,
                          num_dof=5,
                          num_basis=5,
                          duration=2,
                          alpha_phase=2.5,
                          dt=_env.dt,
                          start_pos=_env.start_pos,
                          learn_goal=False,
                          policy_type="velocity",
                          weights_scale=50)

        # Offset the base seed by the rank so different ranks are seeded
        # with different values.
        _env.seed(seed + rank)
        return _env

    return _init
def make_holereacher_env(rank, seed=0):
    """
    Utility function for multiprocessed env.

    Returns a factory for a 5-link ``HoleReacher`` (hole_width=0.25,
    hole_depth=1, hole_x=2) wrapped in a ``DmpWrapper`` with
    ``learn_goal=True``. Nothing is constructed until the returned function
    is called.

    :param rank: (int) index of the subprocess
    :param seed: (int) the initial seed for RNG
    :returns a function that generates an environment
    """

    def _init():
        _env = HoleReacher(n_links=5,
                           allow_self_collision=False,
                           allow_wall_collision=False,
                           hole_width=0.25,
                           hole_depth=1,
                           hole_x=2,
                           collision_penalty=100)

        # dt and start_pos are read from the raw environment before the
        # name is rebound to the wrapper.
        _env = DmpWrapper(_env,
                          num_dof=5,
                          num_basis=5,
                          duration=2,
                          dt=_env.dt,
                          learn_goal=True,
                          alpha_phase=2,
                          start_pos=_env.start_pos,
                          policy_type="velocity",
                          weights_scale=50,
                          goal_scale=0.1
                          )

        # Offset the base seed by the rank so different ranks are seeded
        # with different values.
        _env.seed(seed + rank)
        return _env

    return _init
def make_holereacher_fix_goal_env(rank, seed=0):
    """
    Utility function for multiprocessed env.

    Returns a factory for a 5-link ``HoleReacher`` (hole_width=0.15,
    hole_depth=1, hole_x=1) wrapped in a ``DmpWrapper`` with
    ``learn_goal=False`` and a hard-coded ``final_pos``. Nothing is
    constructed until the returned function is called.

    :param rank: (int) index of the subprocess
    :param seed: (int) the initial seed for RNG
    :returns a function that generates an environment
    """

    def _init():
        _env = HoleReacher(n_links=5,
                           allow_self_collision=False,
                           allow_wall_collision=False,
                           hole_width=0.15,
                           hole_depth=1,
                           hole_x=1,
                           collision_penalty=100)

        # dt and start_pos are read from the raw environment before the
        # name is rebound to the wrapper. The goal is not learned here;
        # final_pos is passed as a constant five-element array instead.
        _env = DmpWrapper(_env,
                          num_dof=5,
                          num_basis=5,
                          duration=2,
                          dt=_env.dt,
                          learn_goal=False,
                          final_pos=np.array([2.02669572, -1.25966385, -1.51618198, -0.80946476, 0.02012344]),
                          alpha_phase=2,
                          start_pos=_env.start_pos,
                          policy_type="velocity",
                          weights_scale=50,
                          goal_scale=1
                          )

        # Offset the base seed by the rank so different ranks are seeded
        # with different values.
        _env.seed(seed + rank)
        return _env

    return _init
def make_holereacher_env_pmp(rank, seed=0):
    """
    Utility function for multiprocessed env.

    Returns a factory for a 5-link ``HoleReacher`` (hole_width=0.15,
    hole_depth=1, hole_x=1) wrapped in a ``DetPMPWrapper``. Nothing is
    constructed until the returned function is called.

    :param rank: (int) index of the subprocess
    :param seed: (int) the initial seed for RNG
    :returns a function that generates an environment
    """

    def _init():
        _env = HoleReacher(n_links=5,
                           allow_self_collision=False,
                           allow_wall_collision=False,
                           hole_width=0.15,
                           hole_depth=1,
                           hole_x=1,
                           collision_penalty=1000)

        # dt and start_pos are read from the raw environment before the
        # name is rebound to the wrapper.
        _env = DetPMPWrapper(_env,
                             num_dof=5,
                             num_basis=5,
                             width=0.02,
                             policy_type="velocity",
                             start_pos=_env.start_pos,
                             duration=2,
                             post_traj_time=0,
                             dt=_env.dt,
                             weights_scale=0.2,
                             zero_start=True,
                             zero_goal=False
                             )

        # Offset the base seed by the rank so different ranks are seeded
        # with different values.
        _env.seed(seed + rank)
        return _env

    return _init