wip
This commit is contained in:
parent
f3d75b9a60
commit
e482fc09f0
@ -1,7 +1,7 @@
|
|||||||
from gym.envs.registration import register
|
from gym.envs.registration import register
|
||||||
|
|
||||||
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
|
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
|
||||||
from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
|
# from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
|
||||||
|
|
||||||
# Mujoco
|
# Mujoco
|
||||||
|
|
||||||
@ -119,37 +119,89 @@ register(
|
|||||||
"n_links": 5,
|
"n_links": 5,
|
||||||
"allow_self_collision": False,
|
"allow_self_collision": False,
|
||||||
"allow_wall_collision": False,
|
"allow_wall_collision": False,
|
||||||
"hole_width": 0.15,
|
"hole_width": 0.25,
|
||||||
"hole_depth": 1,
|
"hole_depth": 1,
|
||||||
"hole_x": 1,
|
"hole_x": 2,
|
||||||
"collision_penalty": 100,
|
"collision_penalty": 100,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# DMP environments
|
# MP environments
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='ViaPointReacherDMP-v0',
|
id='ViaPointReacherDMP-v0',
|
||||||
entry_point='alr_envs.classic_control.viapoint_reacher:viapoint_dmp',
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
|
kwargs={
|
||||||
|
"name": "alr_envs:ViaPointReacher-v0",
|
||||||
|
"num_dof": 5,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 2,
|
||||||
|
"alpha_phase": 2,
|
||||||
|
"learn_goal": False,
|
||||||
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 50,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='HoleReacherDMP-v0',
|
id='HoleReacherDMP-v0',
|
||||||
entry_point='alr_envs.classic_control.hole_reacher:holereacher_dmp',
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
|
kwargs={
|
||||||
|
"name": "alr_envs:HoleReacher-v0",
|
||||||
|
"num_dof": 5,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 2,
|
||||||
|
"learn_goal": True,
|
||||||
|
"alpha_phase": 2,
|
||||||
|
"bandwidth_factor": 2,
|
||||||
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 50,
|
||||||
|
"goal_scale": 0.1
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: properly add final_pos
|
||||||
register(
|
register(
|
||||||
id='HoleReacherFixedGoalDMP-v0',
|
id='HoleReacherFixedGoalDMP-v0',
|
||||||
entry_point='alr_envs.classic_control.hole_reacher:holereacher_fix_goal_dmp',
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
|
kwargs={
|
||||||
|
"name": "alr_envs:HoleReacher-v0",
|
||||||
|
"num_dof": 5,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 2,
|
||||||
|
"learn_goal": False,
|
||||||
|
"alpha_phase": 2,
|
||||||
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 50,
|
||||||
|
"goal_scale": 0.1
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='HoleReacherDetPMP-v0',
|
id='HoleReacherDetPMP-v0',
|
||||||
entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp',
|
entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp',
|
||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
|
# TODO: add mp kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
register(
|
||||||
|
id='BiacSimpleDMP-v0',
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||||
|
kwargs={
|
||||||
|
"name": "alr_envs:HoleReacher-v0",
|
||||||
|
"num_dof": 5,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 2,
|
||||||
|
"learn_goal": True,
|
||||||
|
"alpha_phase": 2,
|
||||||
|
"bandwidth_factor": 2,
|
||||||
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 50,
|
||||||
|
"goal_scale": 0.1
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# BBO functions
|
# BBO functions
|
||||||
|
@ -2,40 +2,7 @@ import gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib import patches
|
from matplotlib import patches
|
||||||
|
from alr_envs.classic_control.utils import check_self_collision
|
||||||
from alr_envs import DmpWrapper
|
|
||||||
from alr_envs.utils.wrapper.detpmp_wrapper import DetPMPWrapper
|
|
||||||
|
|
||||||
|
|
||||||
def ccw(A, B, C):
|
|
||||||
return (C[1] - A[1]) * (B[0] - A[0]) - (B[1] - A[1]) * (C[0] - A[0]) > 1e-12
|
|
||||||
|
|
||||||
|
|
||||||
# Return true if line segments AB and CD intersect
|
|
||||||
def intersect(A, B, C, D):
|
|
||||||
return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)
|
|
||||||
|
|
||||||
|
|
||||||
def holereacher_dmp(**kwargs):
|
|
||||||
_env = gym.make("alr_envs:HoleReacher-v0")
|
|
||||||
# _env = HoleReacher(**kwargs)
|
|
||||||
return DmpWrapper(_env, num_dof=5, num_basis=5, duration=2, dt=_env.dt, learn_goal=True, alpha_phase=3.5,
|
|
||||||
start_pos=_env.start_pos, policy_type="velocity", weights_scale=100, goal_scale=0.1)
|
|
||||||
|
|
||||||
|
|
||||||
def holereacher_fix_goal_dmp(**kwargs):
|
|
||||||
_env = gym.make("alr_envs:HoleReacher-v0")
|
|
||||||
# _env = HoleReacher(**kwargs)
|
|
||||||
return DmpWrapper(_env, num_dof=5, num_basis=5, duration=2, dt=_env.dt, learn_goal=False, alpha_phase=3.5,
|
|
||||||
start_pos=_env.start_pos, policy_type="velocity", weights_scale=50, goal_scale=1,
|
|
||||||
final_pos=np.array([2.02669572, -1.25966385, -1.51618198, -0.80946476, 0.02012344]))
|
|
||||||
|
|
||||||
|
|
||||||
def holereacher_detpmp(**kwargs):
|
|
||||||
_env = gym.make("alr_envs:HoleReacher-v0")
|
|
||||||
# _env = HoleReacher(**kwargs)
|
|
||||||
return DetPMPWrapper(_env, num_dof=5, num_basis=5, width=0.005, policy_type="velocity", start_pos=_env.start_pos,
|
|
||||||
duration=2, post_traj_time=0, dt=_env.dt, weights_scale=0.25, zero_start=True, zero_goal=False)
|
|
||||||
|
|
||||||
|
|
||||||
class HoleReacher(gym.Env):
|
class HoleReacher(gym.Env):
|
||||||
@ -166,7 +133,7 @@ class HoleReacher(gym.Env):
|
|||||||
wall_collision = False
|
wall_collision = False
|
||||||
|
|
||||||
if not self.allow_self_collision:
|
if not self.allow_self_collision:
|
||||||
self_collision = self.check_self_collision(line_points_in_taskspace)
|
self_collision = check_self_collision(line_points_in_taskspace)
|
||||||
if np.any(np.abs(self._joint_angles) > np.pi) and not self.allow_self_collision:
|
if np.any(np.abs(self._joint_angles) > np.pi) and not self.allow_self_collision:
|
||||||
self_collision = True
|
self_collision = True
|
||||||
|
|
||||||
@ -209,14 +176,6 @@ class HoleReacher(gym.Env):
|
|||||||
|
|
||||||
return np.squeeze(endeffector + self._joints[0, :])
|
return np.squeeze(endeffector + self._joints[0, :])
|
||||||
|
|
||||||
def check_self_collision(self, line_points):
|
|
||||||
for i, line1 in enumerate(line_points):
|
|
||||||
for line2 in line_points[i + 2:, :, :]:
|
|
||||||
# if line1 != line2:
|
|
||||||
if intersect(line1[0], line1[-1], line2[0], line2[-1]):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def check_wall_collision(self, line_points):
|
def check_wall_collision(self, line_points):
|
||||||
|
|
||||||
# all points that are before the hole in x
|
# all points that are before the hole in x
|
||||||
|
@ -1,155 +1,17 @@
|
|||||||
from alr_envs.classic_control.hole_reacher import HoleReacher
|
def ccw(A, B, C):
|
||||||
from alr_envs.classic_control.viapoint_reacher import ViaPointReacher
|
return (C[1] - A[1]) * (B[0] - A[0]) - (B[1] - A[1]) * (C[0] - A[0]) > 1e-12
|
||||||
from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
|
|
||||||
from alr_envs.utils.wrapper.detpmp_wrapper import DetPMPWrapper
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def make_viapointreacher_env(rank, seed=0):
|
# Return true if line segments AB and CD intersect
|
||||||
"""
|
def intersect(A, B, C, D):
|
||||||
Utility function for multiprocessed env.
|
return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)
|
||||||
|
|
||||||
:param env_id: (str) the environment ID
|
|
||||||
:param num_env: (int) the number of environments you wish to have in subprocesses
|
|
||||||
:param seed: (int) the initial seed for RNG
|
|
||||||
:param rank: (int) index of the subprocess
|
|
||||||
:returns a function that generates an environment
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _init():
|
|
||||||
_env = ViaPointReacher(n_links=5,
|
|
||||||
allow_self_collision=False,
|
|
||||||
collision_penalty=1000)
|
|
||||||
|
|
||||||
_env = DmpWrapper(_env,
|
|
||||||
num_dof=5,
|
|
||||||
num_basis=5,
|
|
||||||
duration=2,
|
|
||||||
alpha_phase=2.5,
|
|
||||||
dt=_env.dt,
|
|
||||||
start_pos=_env.start_pos,
|
|
||||||
learn_goal=False,
|
|
||||||
policy_type="velocity",
|
|
||||||
weights_scale=50)
|
|
||||||
_env.seed(seed + rank)
|
|
||||||
return _env
|
|
||||||
|
|
||||||
return _init
|
|
||||||
|
|
||||||
|
|
||||||
def make_holereacher_env(rank, seed=0):
|
def check_self_collision(line_points):
|
||||||
"""
|
for i, line1 in enumerate(line_points):
|
||||||
Utility function for multiprocessed env.
|
for line2 in line_points[i + 2:, :, :]:
|
||||||
|
# if line1 != line2:
|
||||||
|
if intersect(line1[0], line1[-1], line2[0], line2[-1]):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
:param env_id: (str) the environment ID
|
|
||||||
:param num_env: (int) the number of environments you wish to have in subprocesses
|
|
||||||
:param seed: (int) the initial seed for RNG
|
|
||||||
:param rank: (int) index of the subprocess
|
|
||||||
:returns a function that generates an environment
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _init():
|
|
||||||
_env = HoleReacher(n_links=5,
|
|
||||||
allow_self_collision=False,
|
|
||||||
allow_wall_collision=False,
|
|
||||||
hole_width=0.25,
|
|
||||||
hole_depth=1,
|
|
||||||
hole_x=2,
|
|
||||||
collision_penalty=100)
|
|
||||||
|
|
||||||
_env = DmpWrapper(_env,
|
|
||||||
num_dof=5,
|
|
||||||
num_basis=5,
|
|
||||||
duration=2,
|
|
||||||
dt=_env.dt,
|
|
||||||
learn_goal=True,
|
|
||||||
alpha_phase=2,
|
|
||||||
start_pos=_env.start_pos,
|
|
||||||
policy_type="velocity",
|
|
||||||
weights_scale=50,
|
|
||||||
goal_scale=0.1
|
|
||||||
)
|
|
||||||
|
|
||||||
_env.seed(seed + rank)
|
|
||||||
return _env
|
|
||||||
|
|
||||||
return _init
|
|
||||||
|
|
||||||
|
|
||||||
def make_holereacher_fix_goal_env(rank, seed=0):
|
|
||||||
"""
|
|
||||||
Utility function for multiprocessed env.
|
|
||||||
|
|
||||||
:param env_id: (str) the environment ID
|
|
||||||
:param num_env: (int) the number of environments you wish to have in subprocesses
|
|
||||||
:param seed: (int) the initial seed for RNG
|
|
||||||
:param rank: (int) index of the subprocess
|
|
||||||
:returns a function that generates an environment
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _init():
|
|
||||||
_env = HoleReacher(n_links=5,
|
|
||||||
allow_self_collision=False,
|
|
||||||
allow_wall_collision=False,
|
|
||||||
hole_width=0.15,
|
|
||||||
hole_depth=1,
|
|
||||||
hole_x=1,
|
|
||||||
collision_penalty=100)
|
|
||||||
|
|
||||||
_env = DmpWrapper(_env,
|
|
||||||
num_dof=5,
|
|
||||||
num_basis=5,
|
|
||||||
duration=2,
|
|
||||||
dt=_env.dt,
|
|
||||||
learn_goal=False,
|
|
||||||
final_pos=np.array([2.02669572, -1.25966385, -1.51618198, -0.80946476, 0.02012344]),
|
|
||||||
alpha_phase=2,
|
|
||||||
start_pos=_env.start_pos,
|
|
||||||
policy_type="velocity",
|
|
||||||
weights_scale=50,
|
|
||||||
goal_scale=1
|
|
||||||
)
|
|
||||||
|
|
||||||
_env.seed(seed + rank)
|
|
||||||
return _env
|
|
||||||
|
|
||||||
return _init
|
|
||||||
|
|
||||||
|
|
||||||
def make_holereacher_env_pmp(rank, seed=0):
|
|
||||||
"""
|
|
||||||
Utility function for multiprocessed env.
|
|
||||||
|
|
||||||
:param env_id: (str) the environment ID
|
|
||||||
:param num_env: (int) the number of environments you wish to have in subprocesses
|
|
||||||
:param seed: (int) the initial seed for RNG
|
|
||||||
:param rank: (int) index of the subprocess
|
|
||||||
:returns a function that generates an environment
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _init():
|
|
||||||
_env = HoleReacher(n_links=5,
|
|
||||||
allow_self_collision=False,
|
|
||||||
allow_wall_collision=False,
|
|
||||||
hole_width=0.15,
|
|
||||||
hole_depth=1,
|
|
||||||
hole_x=1,
|
|
||||||
collision_penalty=1000)
|
|
||||||
|
|
||||||
_env = DetPMPWrapper(_env,
|
|
||||||
num_dof=5,
|
|
||||||
num_basis=5,
|
|
||||||
width=0.02,
|
|
||||||
policy_type="velocity",
|
|
||||||
start_pos=_env.start_pos,
|
|
||||||
duration=2,
|
|
||||||
post_traj_time=0,
|
|
||||||
dt=_env.dt,
|
|
||||||
weights_scale=0.2,
|
|
||||||
zero_start=True,
|
|
||||||
zero_goal=False
|
|
||||||
)
|
|
||||||
_env.seed(seed + rank)
|
|
||||||
return _env
|
|
||||||
|
|
||||||
return _init
|
|
||||||
|
@ -2,15 +2,7 @@ import gym
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from alr_envs import DmpWrapper
|
from alr_envs.classic_control.utils import check_self_collision
|
||||||
from alr_envs.utils.utils import check_self_collision
|
|
||||||
|
|
||||||
|
|
||||||
def viapoint_dmp(**kwargs):
|
|
||||||
_env = gym.make("alr_envs:ViaPointReacher-v0")
|
|
||||||
# _env = ViaPointReacher(**kwargs)
|
|
||||||
return DmpWrapper(_env, num_dof=5, num_basis=5, duration=2, alpha_phase=2.5, dt=_env.dt,
|
|
||||||
start_pos=_env.start_pos, learn_goal=False, policy_type="velocity", weights_scale=50)
|
|
||||||
|
|
||||||
|
|
||||||
class ViaPointReacher(gym.Env):
|
class ViaPointReacher(gym.Env):
|
||||||
|
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from gym import utils
|
from gym import utils
|
||||||
from gym.envs.mujoco import mujoco_env
|
from gym.envs.mujoco import mujoco_env
|
||||||
|
|
||||||
from alr_envs.utils.utils import angle_normalize
|
import alr_envs.utils.utils as alr_utils
|
||||||
|
|
||||||
|
|
||||||
class BalancingEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
class BalancingEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||||
@ -23,7 +23,7 @@ class BalancingEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(__file__), "assets", file_name), 2)
|
mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(__file__), "assets", file_name), 2)
|
||||||
|
|
||||||
def step(self, a):
|
def step(self, a):
|
||||||
angle = angle_normalize(np.sum(self.sim.data.qpos.flat[:self.n_links]), type="rad")
|
angle = alr_utils.angle_normalize(np.sum(self.sim.data.qpos.flat[:self.n_links]), type="rad")
|
||||||
reward = - np.abs(angle)
|
reward = - np.abs(angle)
|
||||||
|
|
||||||
self.do_simulation(a, self.frame_skip)
|
self.do_simulation(a, self.frame_skip)
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_simple import ALRBallInACupEnv
|
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from gym import utils
|
from gym import utils
|
||||||
from gym.envs.mujoco import mujoco_env
|
from gym.envs.mujoco import mujoco_env
|
||||||
|
|
||||||
from alr_envs.utils.utils import angle_normalize
|
import alr_envs.utils.utils as alr_utils
|
||||||
|
|
||||||
|
|
||||||
class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||||
@ -47,7 +47,7 @@ class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
if self.balance:
|
if self.balance:
|
||||||
reward_balance -= self.balance_weight * np.abs(
|
reward_balance -= self.balance_weight * np.abs(
|
||||||
angle_normalize(np.sum(self.sim.data.qpos.flat[:self.n_links]), type="rad"))
|
alr_utils.angle_normalize(np.sum(self.sim.data.qpos.flat[:self.n_links]), type="rad"))
|
||||||
|
|
||||||
reward = reward_dist + reward_ctrl + angular_vel + reward_balance
|
reward = reward_dist + reward_ctrl + angular_vel + reward_balance
|
||||||
self.do_simulation(a, self.frame_skip)
|
self.do_simulation(a, self.frame_skip)
|
||||||
|
0
alr_envs/utils/legacy/__init__.py
Normal file
0
alr_envs/utils/legacy/__init__.py
Normal file
88
alr_envs/utils/legacy/detpmp_env_wrapper.py
Normal file
88
alr_envs/utils/legacy/detpmp_env_wrapper.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from alr_envs.utils.policies import get_policy_class
|
||||||
|
from mp_lib import det_promp
|
||||||
|
import numpy as np
|
||||||
|
import gym
|
||||||
|
|
||||||
|
|
||||||
|
class DetPMPEnvWrapper(gym.Wrapper):
|
||||||
|
def __init__(self,
|
||||||
|
env,
|
||||||
|
num_dof,
|
||||||
|
num_basis,
|
||||||
|
width,
|
||||||
|
off=0.01,
|
||||||
|
start_pos=None,
|
||||||
|
duration=1,
|
||||||
|
dt=0.01,
|
||||||
|
post_traj_time=0.,
|
||||||
|
policy_type=None,
|
||||||
|
weights_scale=1,
|
||||||
|
zero_start=False,
|
||||||
|
zero_goal=False,
|
||||||
|
):
|
||||||
|
super(DetPMPEnvWrapper, self).__init__(env)
|
||||||
|
self.num_dof = num_dof
|
||||||
|
self.num_basis = num_basis
|
||||||
|
self.dim = num_dof * num_basis
|
||||||
|
self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=off,
|
||||||
|
zero_start=zero_start, zero_goal=zero_goal)
|
||||||
|
weights = np.zeros(shape=(num_basis, num_dof))
|
||||||
|
self.pmp.set_weights(duration, weights)
|
||||||
|
self.weights_scale = weights_scale
|
||||||
|
|
||||||
|
self.duration = duration
|
||||||
|
self.dt = dt
|
||||||
|
self.post_traj_steps = int(post_traj_time / dt)
|
||||||
|
|
||||||
|
self.start_pos = start_pos
|
||||||
|
self.zero_start = zero_start
|
||||||
|
|
||||||
|
policy_class = get_policy_class(policy_type)
|
||||||
|
self.policy = policy_class(env)
|
||||||
|
|
||||||
|
def __call__(self, params, contexts=None):
|
||||||
|
params = np.atleast_2d(params)
|
||||||
|
rewards = []
|
||||||
|
infos = []
|
||||||
|
for p, c in zip(params, contexts):
|
||||||
|
reward, info = self.rollout(p, c)
|
||||||
|
rewards.append(reward)
|
||||||
|
infos.append(info)
|
||||||
|
|
||||||
|
return np.array(rewards), infos
|
||||||
|
|
||||||
|
def rollout(self, params, context=None, render=False):
|
||||||
|
""" This function generates a trajectory based on a DMP and then does the usual loop over reset and step"""
|
||||||
|
params = np.reshape(params, newshape=(self.num_basis, self.num_dof)) * self.weights_scale
|
||||||
|
self.pmp.set_weights(self.duration, params)
|
||||||
|
t, des_pos, des_vel, des_acc = self.pmp.compute_trajectory(1 / self.dt, 1.)
|
||||||
|
if self.zero_start:
|
||||||
|
des_pos += self.start_pos[None, :]
|
||||||
|
|
||||||
|
if self.post_traj_steps > 0:
|
||||||
|
des_pos = np.vstack([des_pos, np.tile(des_pos[-1, :], [self.post_traj_steps, 1])])
|
||||||
|
des_vel = np.vstack([des_vel, np.zeros(shape=(self.post_traj_steps, self.num_dof))])
|
||||||
|
|
||||||
|
self._trajectory = des_pos
|
||||||
|
self._velocity = des_vel
|
||||||
|
|
||||||
|
rews = []
|
||||||
|
infos = []
|
||||||
|
|
||||||
|
self.env.configure(context)
|
||||||
|
self.env.reset()
|
||||||
|
|
||||||
|
for t, pos_vel in enumerate(zip(des_pos, des_vel)):
|
||||||
|
ac = self.policy.get_action(pos_vel[0], pos_vel[1])
|
||||||
|
obs, rew, done, info = self.env.step(ac)
|
||||||
|
rews.append(rew)
|
||||||
|
infos.append(info)
|
||||||
|
if render:
|
||||||
|
self.env.render(mode="human")
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
reward = np.sum(rews)
|
||||||
|
|
||||||
|
return reward, info
|
||||||
|
|
125
alr_envs/utils/legacy/dmp_env_wrapper.py
Normal file
125
alr_envs/utils/legacy/dmp_env_wrapper.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
from alr_envs.utils.policies import get_policy_class
|
||||||
|
from mp_lib.phase import ExpDecayPhaseGenerator
|
||||||
|
from mp_lib.basis import DMPBasisGenerator
|
||||||
|
from mp_lib import dmps
|
||||||
|
import numpy as np
|
||||||
|
import gym
|
||||||
|
|
||||||
|
|
||||||
|
class DmpEnvWrapper(gym.Wrapper):
|
||||||
|
def __init__(self,
|
||||||
|
env,
|
||||||
|
num_dof,
|
||||||
|
num_basis,
|
||||||
|
start_pos=None,
|
||||||
|
final_pos=None,
|
||||||
|
duration=1,
|
||||||
|
dt=0.01,
|
||||||
|
alpha_phase=2,
|
||||||
|
bandwidth_factor=3,
|
||||||
|
learn_goal=False,
|
||||||
|
post_traj_time=0.,
|
||||||
|
policy_type=None,
|
||||||
|
weights_scale=1.,
|
||||||
|
goal_scale=1.,
|
||||||
|
):
|
||||||
|
super(DmpEnvWrapper, self).__init__(env)
|
||||||
|
self.num_dof = num_dof
|
||||||
|
self.num_basis = num_basis
|
||||||
|
self.dim = num_dof * num_basis
|
||||||
|
if learn_goal:
|
||||||
|
self.dim += num_dof
|
||||||
|
self.learn_goal = learn_goal
|
||||||
|
self.duration = duration # seconds
|
||||||
|
time_steps = int(duration / dt)
|
||||||
|
self.t = np.linspace(0, duration, time_steps)
|
||||||
|
self.post_traj_steps = int(post_traj_time / dt)
|
||||||
|
|
||||||
|
phase_generator = ExpDecayPhaseGenerator(alpha_phase=alpha_phase, duration=duration)
|
||||||
|
basis_generator = DMPBasisGenerator(phase_generator,
|
||||||
|
duration=duration,
|
||||||
|
num_basis=self.num_basis,
|
||||||
|
basis_bandwidth_factor=bandwidth_factor)
|
||||||
|
|
||||||
|
self.dmp = dmps.DMP(num_dof=num_dof,
|
||||||
|
basis_generator=basis_generator,
|
||||||
|
phase_generator=phase_generator,
|
||||||
|
num_time_steps=time_steps,
|
||||||
|
dt=dt
|
||||||
|
)
|
||||||
|
|
||||||
|
self.dmp.dmp_start_pos = start_pos.reshape((1, num_dof))
|
||||||
|
|
||||||
|
dmp_weights = np.zeros((num_basis, num_dof))
|
||||||
|
if learn_goal:
|
||||||
|
dmp_goal_pos = np.zeros(num_dof)
|
||||||
|
else:
|
||||||
|
dmp_goal_pos = final_pos
|
||||||
|
|
||||||
|
self.dmp.set_weights(dmp_weights, dmp_goal_pos)
|
||||||
|
self.weights_scale = weights_scale
|
||||||
|
self.goal_scale = goal_scale
|
||||||
|
|
||||||
|
policy_class = get_policy_class(policy_type)
|
||||||
|
self.policy = policy_class(env)
|
||||||
|
|
||||||
|
def __call__(self, params, contexts=None):
|
||||||
|
params = np.atleast_2d(params)
|
||||||
|
rewards = []
|
||||||
|
infos = []
|
||||||
|
for p, c in zip(params, contexts):
|
||||||
|
reward, info = self.rollout(p, c)
|
||||||
|
rewards.append(reward)
|
||||||
|
infos.append(info)
|
||||||
|
|
||||||
|
return np.array(rewards), infos
|
||||||
|
|
||||||
|
def goal_and_weights(self, params):
|
||||||
|
if len(params.shape) > 1:
|
||||||
|
assert params.shape[1] == self.dim
|
||||||
|
else:
|
||||||
|
assert len(params) == self.dim
|
||||||
|
params = np.reshape(params, [1, self.dim])
|
||||||
|
|
||||||
|
if self.learn_goal:
|
||||||
|
goal_pos = params[0, -self.num_dof:]
|
||||||
|
weight_matrix = np.reshape(params[:, :-self.num_dof], [self.num_basis, self.num_dof])
|
||||||
|
else:
|
||||||
|
goal_pos = self.dmp.dmp_goal_pos.flatten()
|
||||||
|
assert goal_pos is not None
|
||||||
|
weight_matrix = np.reshape(params, [self.num_basis, self.num_dof])
|
||||||
|
|
||||||
|
return goal_pos * self.goal_scale, weight_matrix * self.weights_scale
|
||||||
|
|
||||||
|
def rollout(self, params, context=None, render=False):
|
||||||
|
""" This function generates a trajectory based on a DMP and then does the usual loop over reset and step"""
|
||||||
|
goal_pos, weight_matrix = self.goal_and_weights(params)
|
||||||
|
self.dmp.set_weights(weight_matrix, goal_pos)
|
||||||
|
trajectory, velocity = self.dmp.reference_trajectory(self.t)
|
||||||
|
|
||||||
|
if self.post_traj_steps > 0:
|
||||||
|
trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
|
||||||
|
velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.num_dof))])
|
||||||
|
|
||||||
|
self._trajectory = trajectory
|
||||||
|
self._velocity = velocity
|
||||||
|
|
||||||
|
rews = []
|
||||||
|
infos = []
|
||||||
|
|
||||||
|
self.env.configure(context)
|
||||||
|
self.env.reset()
|
||||||
|
|
||||||
|
for t, pos_vel in enumerate(zip(trajectory, velocity)):
|
||||||
|
ac = self.policy.get_action(pos_vel[0], pos_vel[1])
|
||||||
|
obs, rew, done, info = self.env.step(ac)
|
||||||
|
rews.append(rew)
|
||||||
|
infos.append(info)
|
||||||
|
if render:
|
||||||
|
self.env.render(mode="human")
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
reward = np.sum(rews)
|
||||||
|
|
||||||
|
return reward, info
|
@ -1,6 +1,4 @@
|
|||||||
from alr_envs.classic_control.utils import make_viapointreacher_env
|
from alr_envs.utils.legacy.utils import make_holereacher_env
|
||||||
from alr_envs.classic_control.utils import make_holereacher_env, make_holereacher_fix_goal_env
|
|
||||||
from alr_envs.utils.dmp_async_vec_env import DmpAsyncVectorEnv
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
@ -1,8 +1,5 @@
|
|||||||
from alr_envs.mujoco.ball_in_a_cup.utils import make_env, make_simple_env, make_simple_dmp_env
|
from alr_envs.mujoco.ball_in_a_cup.utils import make_simple_dmp_env
|
||||||
from alr_envs.utils.dmp_async_vec_env import DmpAsyncVectorEnv
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import wrappers
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
156
alr_envs/utils/legacy/utils.py
Normal file
156
alr_envs/utils/legacy/utils.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
import alr_envs.classic_control.hole_reacher as hr
|
||||||
|
import alr_envs.classic_control.viapoint_reacher as vpr
|
||||||
|
from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
|
||||||
|
from alr_envs.utils.wrapper.detpmp_wrapper import DetPMPWrapper
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def make_viapointreacher_env(rank, seed=0):
|
||||||
|
"""
|
||||||
|
Utility function for multiprocessed env.
|
||||||
|
|
||||||
|
:param env_id: (str) the environment ID
|
||||||
|
:param num_env: (int) the number of environments you wish to have in subprocesses
|
||||||
|
:param seed: (int) the initial seed for RNG
|
||||||
|
:param rank: (int) index of the subprocess
|
||||||
|
:returns a function that generates an environment
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _init():
|
||||||
|
_env = vpr.ViaPointReacher(n_links=5,
|
||||||
|
allow_self_collision=False,
|
||||||
|
collision_penalty=1000)
|
||||||
|
|
||||||
|
_env = DmpWrapper(_env,
|
||||||
|
num_dof=5,
|
||||||
|
num_basis=5,
|
||||||
|
duration=2,
|
||||||
|
alpha_phase=2.5,
|
||||||
|
dt=_env.dt,
|
||||||
|
start_pos=_env.start_pos,
|
||||||
|
learn_goal=False,
|
||||||
|
policy_type="velocity",
|
||||||
|
weights_scale=50)
|
||||||
|
_env.seed(seed + rank)
|
||||||
|
return _env
|
||||||
|
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
def make_holereacher_env(rank, seed=0):
|
||||||
|
"""
|
||||||
|
Utility function for multiprocessed env.
|
||||||
|
|
||||||
|
:param env_id: (str) the environment ID
|
||||||
|
:param num_env: (int) the number of environments you wish to have in subprocesses
|
||||||
|
:param seed: (int) the initial seed for RNG
|
||||||
|
:param rank: (int) index of the subprocess
|
||||||
|
:returns a function that generates an environment
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _init():
|
||||||
|
_env = hr.HoleReacher(n_links=5,
|
||||||
|
allow_self_collision=False,
|
||||||
|
allow_wall_collision=False,
|
||||||
|
hole_width=0.25,
|
||||||
|
hole_depth=1,
|
||||||
|
hole_x=2,
|
||||||
|
collision_penalty=100)
|
||||||
|
|
||||||
|
_env = DmpWrapper(_env,
|
||||||
|
num_dof=5,
|
||||||
|
num_basis=5,
|
||||||
|
duration=2,
|
||||||
|
bandwidth_factor=2,
|
||||||
|
dt=_env.dt,
|
||||||
|
learn_goal=True,
|
||||||
|
alpha_phase=2,
|
||||||
|
start_pos=_env.start_pos,
|
||||||
|
policy_type="velocity",
|
||||||
|
weights_scale=50,
|
||||||
|
goal_scale=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
_env.seed(seed + rank)
|
||||||
|
return _env
|
||||||
|
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
def make_holereacher_fix_goal_env(rank, seed=0):
|
||||||
|
"""
|
||||||
|
Utility function for multiprocessed env.
|
||||||
|
|
||||||
|
:param env_id: (str) the environment ID
|
||||||
|
:param num_env: (int) the number of environments you wish to have in subprocesses
|
||||||
|
:param seed: (int) the initial seed for RNG
|
||||||
|
:param rank: (int) index of the subprocess
|
||||||
|
:returns a function that generates an environment
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _init():
|
||||||
|
_env = hr.HoleReacher(n_links=5,
|
||||||
|
allow_self_collision=False,
|
||||||
|
allow_wall_collision=False,
|
||||||
|
hole_width=0.15,
|
||||||
|
hole_depth=1,
|
||||||
|
hole_x=1,
|
||||||
|
collision_penalty=100)
|
||||||
|
|
||||||
|
_env = DmpWrapper(_env,
|
||||||
|
num_dof=5,
|
||||||
|
num_basis=5,
|
||||||
|
duration=2,
|
||||||
|
dt=_env.dt,
|
||||||
|
learn_goal=False,
|
||||||
|
final_pos=np.array([2.02669572, -1.25966385, -1.51618198, -0.80946476, 0.02012344]),
|
||||||
|
alpha_phase=2,
|
||||||
|
start_pos=_env.start_pos,
|
||||||
|
policy_type="velocity",
|
||||||
|
weights_scale=50,
|
||||||
|
goal_scale=1
|
||||||
|
)
|
||||||
|
|
||||||
|
_env.seed(seed + rank)
|
||||||
|
return _env
|
||||||
|
|
||||||
|
return _init
|
||||||
|
|
||||||
|
|
||||||
|
def make_holereacher_env_pmp(rank, seed=0):
|
||||||
|
"""
|
||||||
|
Utility function for multiprocessed env.
|
||||||
|
|
||||||
|
:param env_id: (str) the environment ID
|
||||||
|
:param num_env: (int) the number of environments you wish to have in subprocesses
|
||||||
|
:param seed: (int) the initial seed for RNG
|
||||||
|
:param rank: (int) index of the subprocess
|
||||||
|
:returns a function that generates an environment
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _init():
|
||||||
|
_env = hr.HoleReacher(n_links=5,
|
||||||
|
allow_self_collision=False,
|
||||||
|
allow_wall_collision=False,
|
||||||
|
hole_width=0.15,
|
||||||
|
hole_depth=1,
|
||||||
|
hole_x=1,
|
||||||
|
collision_penalty=1000)
|
||||||
|
|
||||||
|
_env = DetPMPWrapper(_env,
|
||||||
|
num_dof=5,
|
||||||
|
num_basis=5,
|
||||||
|
width=0.02,
|
||||||
|
policy_type="velocity",
|
||||||
|
start_pos=_env.start_pos,
|
||||||
|
duration=2,
|
||||||
|
post_traj_time=0,
|
||||||
|
dt=_env.dt,
|
||||||
|
weights_scale=0.2,
|
||||||
|
zero_start=True,
|
||||||
|
zero_goal=False
|
||||||
|
)
|
||||||
|
_env.seed(seed + rank)
|
||||||
|
return _env
|
||||||
|
|
||||||
|
return _init
|
137
alr_envs/utils/make_env_helpers.py
Normal file
137
alr_envs/utils/make_env_helpers.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
from alr_envs.utils.wrapper.dmp_wrapper import DmpWrapper
|
||||||
|
from alr_envs.utils.wrapper.detpmp_wrapper import DetPMPWrapper
|
||||||
|
import gym
|
||||||
|
from gym.vector.utils import write_to_shared_memory
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def make_env(env_id, seed, rank):
|
||||||
|
env = gym.make(env_id)
|
||||||
|
env.seed(seed + rank)
|
||||||
|
return lambda: env
|
||||||
|
|
||||||
|
|
||||||
|
def make_contextual_env(env_id, context, seed, rank):
|
||||||
|
env = gym.make(env_id, context=context)
|
||||||
|
env.seed(seed + rank)
|
||||||
|
return lambda: env
|
||||||
|
|
||||||
|
|
||||||
|
def make_dmp_env(**kwargs):
|
||||||
|
name = kwargs.pop("name")
|
||||||
|
_env = gym.make(name)
|
||||||
|
return DmpWrapper(_env, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def make_detpmp_env(**kwargs):
|
||||||
|
name = kwargs.pop("name")
|
||||||
|
_env = gym.make(name)
|
||||||
|
return DetPMPWrapper(_env, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||||
|
# assert shared_memory is None
|
||||||
|
# env = env_fn()
|
||||||
|
# parent_pipe.close()
|
||||||
|
# try:
|
||||||
|
# while True:
|
||||||
|
# command, data = pipe.recv()
|
||||||
|
# if command == 'reset':
|
||||||
|
# observation = env.reset()
|
||||||
|
# pipe.send((observation, True))
|
||||||
|
# elif command == 'configure':
|
||||||
|
# env.configure(data)
|
||||||
|
# pipe.send((None, True))
|
||||||
|
# elif command == 'step':
|
||||||
|
# observation, reward, done, info = env.step(data)
|
||||||
|
# if done:
|
||||||
|
# observation = env.reset()
|
||||||
|
# pipe.send(((observation, reward, done, info), True))
|
||||||
|
# elif command == 'seed':
|
||||||
|
# env.seed(data)
|
||||||
|
# pipe.send((None, True))
|
||||||
|
# elif command == 'close':
|
||||||
|
# pipe.send((None, True))
|
||||||
|
# break
|
||||||
|
# elif command == '_check_observation_space':
|
||||||
|
# pipe.send((data == env.observation_space, True))
|
||||||
|
# else:
|
||||||
|
# raise RuntimeError('Received unknown command `{0}`. Must '
|
||||||
|
# 'be one of {`reset`, `step`, `seed`, `close`, '
|
||||||
|
# '`_check_observation_space`}.'.format(command))
|
||||||
|
# except (KeyboardInterrupt, Exception):
|
||||||
|
# error_queue.put((index,) + sys.exc_info()[:2])
|
||||||
|
# pipe.send((None, False))
|
||||||
|
# finally:
|
||||||
|
# env.close()
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||||
|
# assert shared_memory is not None
|
||||||
|
# env = env_fn()
|
||||||
|
# observation_space = env.observation_space
|
||||||
|
# parent_pipe.close()
|
||||||
|
# try:
|
||||||
|
# while True:
|
||||||
|
# command, data = pipe.recv()
|
||||||
|
# if command == 'reset':
|
||||||
|
# observation = env.reset()
|
||||||
|
# write_to_shared_memory(index, observation, shared_memory,
|
||||||
|
# observation_space)
|
||||||
|
# pipe.send((None, True))
|
||||||
|
# elif command == 'configure':
|
||||||
|
# env.configure(data)
|
||||||
|
# pipe.send((None, True))
|
||||||
|
# elif command == 'step':
|
||||||
|
# observation, reward, done, info = env.step(data)
|
||||||
|
# if done:
|
||||||
|
# observation = env.reset()
|
||||||
|
# write_to_shared_memory(index, observation, shared_memory,
|
||||||
|
# observation_space)
|
||||||
|
# pipe.send(((None, reward, done, info), True))
|
||||||
|
# elif command == 'seed':
|
||||||
|
# env.seed(data)
|
||||||
|
# pipe.send((None, True))
|
||||||
|
# elif command == 'close':
|
||||||
|
# pipe.send((None, True))
|
||||||
|
# break
|
||||||
|
# elif command == '_check_observation_space':
|
||||||
|
# pipe.send((data == observation_space, True))
|
||||||
|
# else:
|
||||||
|
# raise RuntimeError('Received unknown command `{0}`. Must '
|
||||||
|
# 'be one of {`reset`, `step`, `seed`, `close`, '
|
||||||
|
# '`_check_observation_space`}.'.format(command))
|
||||||
|
# except (KeyboardInterrupt, Exception):
|
||||||
|
# error_queue.put((index,) + sys.exc_info()[:2])
|
||||||
|
# pipe.send((None, False))
|
||||||
|
# finally:
|
||||||
|
# env.close()
|
||||||
|
|
||||||
|
|
||||||
|
# def viapoint_dmp(**kwargs):
|
||||||
|
# _env = gym.make("alr_envs:ViaPointReacher-v0")
|
||||||
|
# # _env = ViaPointReacher(**kwargs)
|
||||||
|
# return DmpWrapper(_env, num_dof=5, num_basis=5, duration=2, alpha_phase=2.5, dt=_env.dt,
|
||||||
|
# start_pos=_env.start_pos, learn_goal=False, policy_type="velocity", weights_scale=50)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def holereacher_dmp(**kwargs):
|
||||||
|
# _env = gym.make("alr_envs:HoleReacher-v0")
|
||||||
|
# # _env = HoleReacher(**kwargs)
|
||||||
|
# return DmpWrapper(_env, num_dof=5, num_basis=5, duration=2, dt=_env.dt, learn_goal=True, alpha_phase=2,
|
||||||
|
# start_pos=_env.start_pos, policy_type="velocity", weights_scale=50, goal_scale=0.1)
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def holereacher_fix_goal_dmp(**kwargs):
|
||||||
|
# _env = gym.make("alr_envs:HoleReacher-v0")
|
||||||
|
# # _env = HoleReacher(**kwargs)
|
||||||
|
# return DmpWrapper(_env, num_dof=5, num_basis=5, duration=2, dt=_env.dt, learn_goal=False, alpha_phase=2,
|
||||||
|
# start_pos=_env.start_pos, policy_type="velocity", weights_scale=50, goal_scale=1,
|
||||||
|
# final_pos=np.array([2.02669572, -1.25966385, -1.51618198, -0.80946476, 0.02012344]))
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def holereacher_detpmp(**kwargs):
|
||||||
|
# _env = gym.make("alr_envs:HoleReacher-v0")
|
||||||
|
# # _env = HoleReacher(**kwargs)
|
||||||
|
# return DetPMPWrapper(_env, num_dof=5, num_basis=5, width=0.005, policy_type="velocity", start_pos=_env.start_pos,
|
||||||
|
# duration=2, post_traj_time=0, dt=_env.dt, weights_scale=0.25, zero_start=True, zero_goal=False)
|
@ -20,30 +20,3 @@ def angle_normalize(x, type="deg"):
|
|||||||
two_pi = 2 * np.pi
|
two_pi = 2 * np.pi
|
||||||
return x - two_pi * np.floor((x + np.pi) / two_pi)
|
return x - two_pi * np.floor((x + np.pi) / two_pi)
|
||||||
|
|
||||||
|
|
||||||
def ccw(A, B, C):
|
|
||||||
return (C[1] - A[1]) * (B[0] - A[0]) - (B[1] - A[1]) * (C[0] - A[0]) > 1e-12
|
|
||||||
|
|
||||||
|
|
||||||
def intersect(A, B, C, D):
|
|
||||||
"""
|
|
||||||
Return true if line segments AB and CD intersects
|
|
||||||
Args:
|
|
||||||
A: start point line one
|
|
||||||
B: end point line one
|
|
||||||
C: start point line two
|
|
||||||
D: end point line two
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
"""
|
|
||||||
return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)
|
|
||||||
|
|
||||||
|
|
||||||
def check_self_collision(line_points):
|
|
||||||
for i, line1 in enumerate(line_points):
|
|
||||||
for line2 in line_points[i + 2:, :, :]:
|
|
||||||
# if line1 != line2:
|
|
||||||
if intersect(line1[0], line1[-1], line2[0], line2[-1]):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
from alr_envs.utils.policies import get_policy_class
|
|
||||||
from mp_lib.phase import ExpDecayPhaseGenerator
|
from mp_lib.phase import ExpDecayPhaseGenerator
|
||||||
from mp_lib.basis import DMPBasisGenerator
|
from mp_lib.basis import DMPBasisGenerator
|
||||||
from mp_lib import dmps
|
from mp_lib import dmps
|
||||||
@ -11,9 +10,9 @@ from alr_envs.utils.wrapper.mp_wrapper import MPWrapper
|
|||||||
class DmpWrapper(MPWrapper):
|
class DmpWrapper(MPWrapper):
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, num_dof: int, num_basis: int, start_pos: np.ndarray = None,
|
def __init__(self, env: gym.Env, num_dof: int, num_basis: int, start_pos: np.ndarray = None,
|
||||||
final_pos: np.ndarray = None, duration: int = 1, alpha_phase: float = 2., dt: float = 0.01,
|
final_pos: np.ndarray = None, duration: int = 1, alpha_phase: float = 2., dt: float = None,
|
||||||
learn_goal: bool = False, post_traj_time: float = 0., policy_type: str = None,
|
learn_goal: bool = False, post_traj_time: float = 0., policy_type: str = None,
|
||||||
weights_scale: float = 1., goal_scale: float = 1.):
|
weights_scale: float = 1., goal_scale: float = 1., bandwidth_factor: float = 3.):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This Wrapper generates a trajectory based on a DMP and will only return episodic performances.
|
This Wrapper generates a trajectory based on a DMP and will only return episodic performances.
|
||||||
@ -33,20 +32,26 @@ class DmpWrapper(MPWrapper):
|
|||||||
goal_scale:
|
goal_scale:
|
||||||
"""
|
"""
|
||||||
self.learn_goal = learn_goal
|
self.learn_goal = learn_goal
|
||||||
|
dt = env.dt if hasattr(env, "dt") else dt
|
||||||
|
assert dt is not None
|
||||||
|
start_pos = env.start_pos if hasattr(env, "start_pos") else start_pos
|
||||||
|
assert start_pos is not None
|
||||||
self.t = np.linspace(0, duration, int(duration / dt))
|
self.t = np.linspace(0, duration, int(duration / dt))
|
||||||
self.goal_scale = goal_scale
|
self.goal_scale = goal_scale
|
||||||
|
|
||||||
super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale,
|
super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale,
|
||||||
num_basis=num_basis, start_pos=start_pos, final_pos=final_pos, alpha_phase=alpha_phase)
|
num_basis=num_basis, start_pos=start_pos, final_pos=final_pos, alpha_phase=alpha_phase,
|
||||||
|
bandwidth_factor=bandwidth_factor)
|
||||||
|
|
||||||
action_bounds = np.inf * np.ones((np.prod(self.mp.dmp_weights.shape) + (num_dof if learn_goal else 0)))
|
action_bounds = np.inf * np.ones((np.prod(self.mp.dmp_weights.shape) + (num_dof if learn_goal else 0)))
|
||||||
self.action_space = gym.spaces.Box(low=-action_bounds, high=action_bounds, dtype=np.float32)
|
self.action_space = gym.spaces.Box(low=-action_bounds, high=action_bounds, dtype=np.float32)
|
||||||
|
|
||||||
def initialize_mp(self, num_dof: int, duration: int, dt: float, num_basis: int = 5, start_pos: np.ndarray = None,
|
def initialize_mp(self, num_dof: int, duration: int, dt: float, num_basis: int = 5, start_pos: np.ndarray = None,
|
||||||
final_pos: np.ndarray = None, alpha_phase: float = 2.):
|
final_pos: np.ndarray = None, alpha_phase: float = 2., bandwidth_factor: float = 3.):
|
||||||
|
|
||||||
phase_generator = ExpDecayPhaseGenerator(alpha_phase=alpha_phase, duration=duration)
|
phase_generator = ExpDecayPhaseGenerator(alpha_phase=alpha_phase, duration=duration)
|
||||||
basis_generator = DMPBasisGenerator(phase_generator, duration=duration, num_basis=num_basis)
|
basis_generator = DMPBasisGenerator(phase_generator, duration=duration, num_basis=num_basis,
|
||||||
|
basis_bandwidth_factor=bandwidth_factor)
|
||||||
|
|
||||||
dmp = dmps.DMP(num_dof=num_dof, basis_generator=basis_generator, phase_generator=phase_generator,
|
dmp = dmps.DMP(num_dof=num_dof, basis_generator=basis_generator, phase_generator=phase_generator,
|
||||||
num_time_steps=int(duration / dt), dt=dt)
|
num_time_steps=int(duration / dt), dt=dt)
|
||||||
|
@ -13,19 +13,20 @@ class MPWrapper(gym.Wrapper, ABC):
|
|||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
num_dof: int,
|
num_dof: int,
|
||||||
duration: int = 1,
|
duration: int = 1,
|
||||||
dt: float = 0.01,
|
dt: float = None,
|
||||||
# learn_goal: bool = False,
|
|
||||||
post_traj_time: float = 0.,
|
post_traj_time: float = 0.,
|
||||||
policy_type: str = None,
|
policy_type: str = None,
|
||||||
weights_scale: float = 1.,
|
weights_scale: float = 1.,
|
||||||
**mp_kwargs
|
**mp_kwargs
|
||||||
|
|
||||||
):
|
):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
# self.num_dof = num_dof
|
# self.num_dof = num_dof
|
||||||
# self.num_basis = num_basis
|
# self.num_basis = num_basis
|
||||||
# self.duration = duration # seconds
|
# self.duration = duration # seconds
|
||||||
|
|
||||||
|
# dt = env.dt if hasattr(env, "dt") else dt
|
||||||
|
assert dt is not None # this should never happen as MPWrapper is a base class
|
||||||
self.post_traj_steps = int(post_traj_time / dt)
|
self.post_traj_steps = int(post_traj_time / dt)
|
||||||
|
|
||||||
self.mp = self.initialize_mp(num_dof, duration, dt, **mp_kwargs)
|
self.mp = self.initialize_mp(num_dof, duration, dt, **mp_kwargs)
|
||||||
@ -38,6 +39,26 @@ class MPWrapper(gym.Wrapper, ABC):
|
|||||||
self.render_mode = None
|
self.render_mode = None
|
||||||
self.render_kwargs = None
|
self.render_kwargs = None
|
||||||
|
|
||||||
|
# TODO: not yet final
|
||||||
|
def __call__(self, params, contexts=None):
|
||||||
|
params = np.atleast_2d(params)
|
||||||
|
obs = []
|
||||||
|
rewards = []
|
||||||
|
dones = []
|
||||||
|
infos = []
|
||||||
|
for p, c in zip(params, contexts):
|
||||||
|
self.configure(c)
|
||||||
|
ob, reward, done, info = self.step(p)
|
||||||
|
obs.append(ob)
|
||||||
|
rewards.append(reward)
|
||||||
|
dones.append(done)
|
||||||
|
infos.append(info)
|
||||||
|
|
||||||
|
return obs, np.array(rewards), dones, infos
|
||||||
|
|
||||||
|
def configure(self, context):
|
||||||
|
self.env.configure(context)
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
""" This function generates a trajectory based on a DMP and then does the usual loop over reset and step"""
|
""" This function generates a trajectory based on a DMP and then does the usual loop over reset and step"""
|
||||||
trajectory, velocity = self.mp_rollout(action)
|
trajectory, velocity = self.mp_rollout(action)
|
||||||
@ -53,6 +74,7 @@ class MPWrapper(gym.Wrapper, ABC):
|
|||||||
# infos = defaultdict(list)
|
# infos = defaultdict(list)
|
||||||
|
|
||||||
# TODO: @Max Why do we need this configure, states should be part of the model
|
# TODO: @Max Why do we need this configure, states should be part of the model
|
||||||
|
# TODO: Ask Onur if the context distribution needs to be outside the environment
|
||||||
# self.env.configure(context)
|
# self.env.configure(context)
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
info = {}
|
info = {}
|
||||||
@ -77,8 +99,8 @@ class MPWrapper(gym.Wrapper, ABC):
|
|||||||
self.render_mode = mode
|
self.render_mode = mode
|
||||||
self.render_kwargs = kwargs
|
self.render_kwargs = kwargs
|
||||||
|
|
||||||
def __call__(self, actions):
|
# def __call__(self, actions):
|
||||||
return self.step(actions)
|
# return self.step(actions)
|
||||||
# params = np.atleast_2d(params)
|
# params = np.atleast_2d(params)
|
||||||
# rewards = []
|
# rewards = []
|
||||||
# infos = []
|
# infos = []
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -83,4 +82,6 @@ def example_async(n_cpu=4, seed=int('533D', 16)):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# example_mujoco()
|
# example_mujoco()
|
||||||
# example_dmp()
|
# example_dmp()
|
||||||
example_async()
|
# example_async()
|
||||||
|
env = gym.make("alr_envs:HoleReacherDMP-v0", context=0.1)
|
||||||
|
print()
|
Loading…
Reference in New Issue
Block a user