merge api support
This commit is contained in:
commit
af8e868309
2
.gitignore
vendored
2
.gitignore
vendored
@ -109,3 +109,5 @@ venv.bak/
|
|||||||
|
|
||||||
#configs
|
#configs
|
||||||
/configs/db.cfg
|
/configs/db.cfg
|
||||||
|
legacy/
|
||||||
|
MUJOCO_LOG.TXT
|
||||||
|
@ -1,15 +0,0 @@
|
|||||||
Fri Aug 28 14:41:56 2020
|
|
||||||
ERROR: GLEW initalization error: Missing GL version
|
|
||||||
|
|
||||||
Fri Aug 28 14:59:14 2020
|
|
||||||
ERROR: GLEW initalization error: Missing GL version
|
|
||||||
|
|
||||||
Fri Aug 28 15:03:43 2020
|
|
||||||
ERROR: GLEW initalization error: Missing GL version
|
|
||||||
|
|
||||||
Fri Aug 28 15:07:03 2020
|
|
||||||
ERROR: GLEW initalization error: Missing GL version
|
|
||||||
|
|
||||||
Fri Aug 28 15:15:01 2020
|
|
||||||
ERROR: GLEW initalization error: Missing GL version
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
|||||||
|
import numpy as np
|
||||||
from gym.envs.registration import register
|
from gym.envs.registration import register
|
||||||
|
|
||||||
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
|
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
|
||||||
@ -321,7 +322,9 @@ register(
|
|||||||
"bandwidth_factor": 2.5,
|
"bandwidth_factor": 2.5,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 100,
|
"weights_scale": 100,
|
||||||
"return_to_start": True
|
"return_to_start": True,
|
||||||
|
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||||
|
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -339,7 +342,9 @@ register(
|
|||||||
"bandwidth_factor": 2.5,
|
"bandwidth_factor": 2.5,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 100,
|
"weights_scale": 100,
|
||||||
"return_to_start": True
|
"return_to_start": True,
|
||||||
|
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||||
|
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -357,7 +362,9 @@ register(
|
|||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 0.2,
|
"weights_scale": 0.2,
|
||||||
"zero_start": True,
|
"zero_start": True,
|
||||||
"zero_goal": True
|
"zero_goal": True,
|
||||||
|
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
|
||||||
|
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -374,7 +381,9 @@ register(
|
|||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 0.2,
|
"weights_scale": 0.2,
|
||||||
"zero_start": True,
|
"zero_start": True,
|
||||||
"zero_goal": True
|
"zero_goal": True,
|
||||||
|
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
|
||||||
|
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -392,7 +401,9 @@ register(
|
|||||||
"bandwidth_factor": 2.5,
|
"bandwidth_factor": 2.5,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 50,
|
"weights_scale": 50,
|
||||||
"goal_scale": 0.1
|
"goal_scale": 0.1,
|
||||||
|
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||||
|
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from gym.utils import seeding
|
|||||||
from matplotlib import patches
|
from matplotlib import patches
|
||||||
|
|
||||||
from alr_envs.classic_control.utils import check_self_collision
|
from alr_envs.classic_control.utils import check_self_collision
|
||||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
from alr_envs.utils.mps.alr_env import AlrEnv
|
||||||
|
|
||||||
|
|
||||||
class HoleReacherEnv(AlrEnv):
|
class HoleReacherEnv(AlrEnv):
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
from gym import spaces
|
from gym import spaces
|
||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
|
|
||||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
from alr_envs.utils.mps.alr_env import AlrEnv
|
||||||
|
|
||||||
|
|
||||||
class SimpleReacherEnv(AlrEnv):
|
class SimpleReacherEnv(AlrEnv):
|
||||||
|
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
|
|
||||||
from alr_envs.classic_control.utils import check_self_collision
|
from alr_envs.classic_control.utils import check_self_collision
|
||||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
from alr_envs.utils.mps.alr_env import AlrEnv
|
||||||
|
|
||||||
|
|
||||||
class ViaPointReacher(AlrEnv):
|
class ViaPointReacher(AlrEnv):
|
||||||
|
@ -7,7 +7,9 @@ from gym import error, spaces
|
|||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from os import path
|
from os import path
|
||||||
import gym
|
|
||||||
|
from alr_envs.utils.mps.alr_env import AlrEnv
|
||||||
|
from alr_envs.utils.positional_env import PositionalEnv
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import mujoco_py
|
import mujoco_py
|
||||||
@ -33,7 +35,7 @@ def convert_observation_to_space(observation):
|
|||||||
return space
|
return space
|
||||||
|
|
||||||
|
|
||||||
class AlrMujocoEnv(gym.Env):
|
class AlrMujocoEnv(PositionalEnv, AlrEnv):
|
||||||
"""
|
"""
|
||||||
Superclass for all MuJoCo environments.
|
Superclass for all MuJoCo environments.
|
||||||
"""
|
"""
|
||||||
@ -44,7 +46,7 @@ class AlrMujocoEnv(gym.Env):
|
|||||||
Args:
|
Args:
|
||||||
model_path: path to xml file
|
model_path: path to xml file
|
||||||
n_substeps: how many steps mujoco does per call to env.step
|
n_substeps: how many steps mujoco does per call to env.step
|
||||||
use_servo: use actuator defined in xml, use False for direct torque control
|
apply_gravity_comp: Whether gravity compensation should be active
|
||||||
"""
|
"""
|
||||||
if model_path.startswith("/"):
|
if model_path.startswith("/"):
|
||||||
fullpath = model_path
|
fullpath = model_path
|
||||||
@ -73,10 +75,6 @@ class AlrMujocoEnv(gym.Env):
|
|||||||
|
|
||||||
self._set_action_space()
|
self._set_action_space()
|
||||||
|
|
||||||
# action = self.action_space.sample()
|
|
||||||
# observation, _reward, done, _info = self.step(action)
|
|
||||||
# assert not done
|
|
||||||
|
|
||||||
observation = self._get_obs() # TODO: is calling get_obs enough? should we call reset, or even step?
|
observation = self._get_obs() # TODO: is calling get_obs enough? should we call reset, or even step?
|
||||||
|
|
||||||
self._set_observation_space(observation)
|
self._set_observation_space(observation)
|
||||||
@ -204,7 +202,7 @@ class AlrMujocoEnv(gym.Env):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
self.sim.step()
|
self.sim.step()
|
||||||
except mujoco_py.builder.MujocoException as e:
|
except mujoco_py.builder.MujocoException:
|
||||||
error_in_sim = True
|
error_in_sim = True
|
||||||
|
|
||||||
return error_in_sim
|
return error_in_sim
|
||||||
|
@ -16,8 +16,6 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
|
|||||||
self._q_vel = []
|
self._q_vel = []
|
||||||
# self.weight_matrix_scale = 50
|
# self.weight_matrix_scale = 50
|
||||||
self.max_ctrl = np.array([150., 125., 40., 60., 5., 5., 2.])
|
self.max_ctrl = np.array([150., 125., 40., 60., 5., 5., 2.])
|
||||||
self.p_gains = 1 / self.max_ctrl * np.array([200, 300, 100, 100, 10, 10, 2.5])
|
|
||||||
self.d_gains = 1 / self.max_ctrl * np.array([7, 15, 5, 2.5, 0.3, 0.3, 0.05])
|
|
||||||
|
|
||||||
self.j_min = np.array([-2.6, -1.985, -2.8, -0.9, -4.55, -1.5707, -2.7])
|
self.j_min = np.array([-2.6, -1.985, -2.8, -0.9, -4.55, -1.5707, -2.7])
|
||||||
self.j_max = np.array([2.6, 1.985, 2.8, 3.14159, 1.25, 1.5707, 2.7])
|
self.j_max = np.array([2.6, 1.985, 2.8, 3.14159, 1.25, 1.5707, 2.7])
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from alr_envs.utils.mps.dmp_wrapper import DmpWrapper
|
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
||||||
from alr_envs.utils.mps.detpmp_wrapper import DetPMPWrapper
|
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
|
||||||
import gym
|
import gym
|
||||||
from gym.vector.utils import write_to_shared_memory
|
from gym.vector.utils import write_to_shared_memory
|
||||||
import sys
|
import sys
|
||||||
|
@ -91,20 +91,21 @@ class AlrContextualMpEnvSampler:
|
|||||||
|
|
||||||
repeat = int(np.ceil(n_samples / self.env.num_envs))
|
repeat = int(np.ceil(n_samples / self.env.num_envs))
|
||||||
vals = defaultdict(list)
|
vals = defaultdict(list)
|
||||||
|
|
||||||
|
obs = self.env.reset()
|
||||||
for i in range(repeat):
|
for i in range(repeat):
|
||||||
new_contexts = self.env.reset()
|
vals['obs'].append(obs)
|
||||||
vals['new_contexts'].append(new_contexts)
|
new_samples, new_contexts = dist.sample(obs)
|
||||||
new_samples, new_contexts = dist.sample(new_contexts)
|
|
||||||
vals['new_samples'].append(new_samples)
|
vals['new_samples'].append(new_samples)
|
||||||
|
|
||||||
obs, reward, done, info = self.env.step(new_samples)
|
obs, reward, done, info = self.env.step(new_samples)
|
||||||
vals['obs'].append(obs)
|
|
||||||
vals['reward'].append(reward)
|
vals['reward'].append(reward)
|
||||||
vals['done'].append(done)
|
vals['done'].append(done)
|
||||||
vals['info'].append(info)
|
vals['info'].append(info)
|
||||||
|
|
||||||
# do not return values above threshold
|
# do not return values above threshold
|
||||||
return np.vstack(vals['new_samples'])[:n_samples], np.vstack(vals['new_contexts'])[:n_samples], \
|
return np.vstack(vals['new_samples'])[:n_samples], \
|
||||||
np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples], \
|
np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples], \
|
||||||
_flatten_list(vals['done'])[:n_samples], _flatten_list(vals['info'])[:n_samples]
|
_flatten_list(vals['done'])[:n_samples], _flatten_list(vals['info'])[:n_samples]
|
||||||
|
|
||||||
|
@ -1,42 +0,0 @@
|
|||||||
import gym
|
|
||||||
import numpy as np
|
|
||||||
from mp_lib import det_promp
|
|
||||||
|
|
||||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
|
||||||
from alr_envs.utils.mps.mp_wrapper import MPWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class DetPMPWrapper(MPWrapper):
|
|
||||||
def __init__(self, env: AlrEnv, num_dof: int, num_basis: int, width: float, duration: float = 1, dt: float = 0.01,
|
|
||||||
post_traj_time: float = 0., policy_type: str = None, weights_scale: float = 1.,
|
|
||||||
zero_start: bool = False, zero_goal: bool = False, **mp_kwargs):
|
|
||||||
self.duration = duration # seconds
|
|
||||||
|
|
||||||
dt = env.dt if hasattr(env, "dt") else dt
|
|
||||||
assert dt is not None
|
|
||||||
self.dt = dt
|
|
||||||
|
|
||||||
super().__init__(env, num_dof, dt, duration, post_traj_time, policy_type, weights_scale, num_basis=num_basis,
|
|
||||||
width=width, zero_start=zero_start, zero_goal=zero_goal, **mp_kwargs)
|
|
||||||
|
|
||||||
action_bounds = np.inf * np.ones((self.mp.n_basis * self.mp.n_dof))
|
|
||||||
self.action_space = gym.spaces.Box(low=-action_bounds, high=action_bounds, dtype=np.float32)
|
|
||||||
|
|
||||||
def initialize_mp(self, num_dof: int, duration: int, dt: float, num_basis: int = 5, width: float = None,
|
|
||||||
off: float = 0.01, zero_start: bool = False, zero_goal: bool = False):
|
|
||||||
pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=off,
|
|
||||||
zero_start=zero_start, zero_goal=zero_goal)
|
|
||||||
|
|
||||||
weights = np.zeros(shape=(num_basis, num_dof))
|
|
||||||
pmp.set_weights(duration, weights)
|
|
||||||
|
|
||||||
return pmp
|
|
||||||
|
|
||||||
def mp_rollout(self, action):
|
|
||||||
params = np.reshape(action, newshape=(self.mp.n_basis, self.mp.n_dof)) * self.weights_scale
|
|
||||||
self.mp.set_weights(self.duration, params)
|
|
||||||
_, des_pos, des_vel, _ = self.mp.compute_trajectory(1 / self.dt, 1.)
|
|
||||||
if self.mp.zero_start:
|
|
||||||
des_pos += self.env.start_pos[None, :]
|
|
||||||
|
|
||||||
return des_pos, des_vel
|
|
@ -1,76 +0,0 @@
|
|||||||
import gym
|
|
||||||
import numpy as np
|
|
||||||
from mp_lib import dmps
|
|
||||||
from mp_lib.basis import DMPBasisGenerator
|
|
||||||
from mp_lib.phase import ExpDecayPhaseGenerator
|
|
||||||
|
|
||||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
|
||||||
from alr_envs.utils.mps.mp_wrapper import MPWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class DmpWrapper(MPWrapper):
|
|
||||||
|
|
||||||
def __init__(self, env: AlrEnv, num_dof: int, num_basis: int,
|
|
||||||
duration: int = 1, alpha_phase: float = 2., dt: float = None,
|
|
||||||
learn_goal: bool = False, post_traj_time: float = 0.,
|
|
||||||
weights_scale: float = 1., goal_scale: float = 1., bandwidth_factor: float = 3.,
|
|
||||||
policy_type: str = None, render_mode: str = None):
|
|
||||||
|
|
||||||
"""
|
|
||||||
This Wrapper generates a trajectory based on a DMP and will only return episodic performances.
|
|
||||||
Args:
|
|
||||||
env:
|
|
||||||
num_dof:
|
|
||||||
num_basis:
|
|
||||||
duration:
|
|
||||||
alpha_phase:
|
|
||||||
dt:
|
|
||||||
learn_goal:
|
|
||||||
post_traj_time:
|
|
||||||
policy_type:
|
|
||||||
weights_scale:
|
|
||||||
goal_scale:
|
|
||||||
"""
|
|
||||||
self.learn_goal = learn_goal
|
|
||||||
dt = env.dt if hasattr(env, "dt") else dt
|
|
||||||
assert dt is not None
|
|
||||||
self.t = np.linspace(0, duration, int(duration / dt))
|
|
||||||
self.goal_scale = goal_scale
|
|
||||||
|
|
||||||
super().__init__(env, num_dof, dt, duration, post_traj_time, policy_type, weights_scale, render_mode,
|
|
||||||
num_basis=num_basis, alpha_phase=alpha_phase, bandwidth_factor=bandwidth_factor)
|
|
||||||
|
|
||||||
action_bounds = np.inf * np.ones((np.prod(self.mp.weights.shape) + (num_dof if learn_goal else 0)))
|
|
||||||
self.action_space = gym.spaces.Box(low=-action_bounds, high=action_bounds, dtype=np.float32)
|
|
||||||
|
|
||||||
def initialize_mp(self, num_dof: int, duration: int, dt: float, num_basis: int = 5, alpha_phase: float = 2.,
|
|
||||||
bandwidth_factor: int = 3):
|
|
||||||
|
|
||||||
phase_generator = ExpDecayPhaseGenerator(alpha_phase=alpha_phase, duration=duration)
|
|
||||||
basis_generator = DMPBasisGenerator(phase_generator, duration=duration, num_basis=num_basis,
|
|
||||||
basis_bandwidth_factor=bandwidth_factor)
|
|
||||||
|
|
||||||
dmp = dmps.DMP(num_dof=num_dof, basis_generator=basis_generator, phase_generator=phase_generator,
|
|
||||||
duration=duration, dt=dt)
|
|
||||||
|
|
||||||
return dmp
|
|
||||||
|
|
||||||
def goal_and_weights(self, params):
|
|
||||||
assert params.shape[-1] == self.action_space.shape[0]
|
|
||||||
params = np.atleast_2d(params)
|
|
||||||
|
|
||||||
if self.learn_goal:
|
|
||||||
goal_pos = params[0, -self.mp.num_dimensions:] # [num_dof]
|
|
||||||
params = params[:, :-self.mp.num_dimensions] # [1,num_dof]
|
|
||||||
else:
|
|
||||||
goal_pos = self.env.goal_pos
|
|
||||||
assert goal_pos is not None
|
|
||||||
|
|
||||||
weight_matrix = np.reshape(params, self.mp.weights.shape) # [num_basis, num_dof]
|
|
||||||
return goal_pos * self.goal_scale, weight_matrix * self.weights_scale
|
|
||||||
|
|
||||||
def mp_rollout(self, action):
|
|
||||||
self.mp.dmp_start_pos = self.env.start_pos
|
|
||||||
goal_pos, weight_matrix = self.goal_and_weights(action)
|
|
||||||
self.mp.set_weights(weight_matrix, goal_pos)
|
|
||||||
return self.mp.reference_trajectory(self.t)
|
|
@ -1,33 +0,0 @@
|
|||||||
from abc import abstractmethod
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
class AlrEnv(gym.Env):
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def active_obs(self):
|
|
||||||
"""Returns boolean mask for each observation entry
|
|
||||||
whether the observation is returned for the contextual case or not.
|
|
||||||
This effectively allows to filter unwanted or unnecessary observations from the full step-based case.
|
|
||||||
"""
|
|
||||||
return np.ones(self.observation_space.shape, dtype=bool)
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def start_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
"""
|
|
||||||
Returns the starting position of the joints
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
"""
|
|
||||||
Returns the current final position of the joints for the MP.
|
|
||||||
By default this returns the starting position.
|
|
||||||
"""
|
|
||||||
return self.start_pos
|
|
@ -1,113 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
|
||||||
from alr_envs.utils.policies import get_policy_class
|
|
||||||
|
|
||||||
|
|
||||||
class MPWrapper(gym.Wrapper, ABC):
|
|
||||||
|
|
||||||
def __init__(self, env: AlrEnv, num_dof: int, dt: float, duration: float = 1, post_traj_time: float = 0.,
|
|
||||||
policy_type: str = None, weights_scale: float = 1., render_mode: str = None, **mp_kwargs):
|
|
||||||
super().__init__(env)
|
|
||||||
|
|
||||||
# adjust observation space to reduce version
|
|
||||||
obs_sp = self.env.observation_space
|
|
||||||
self.observation_space = gym.spaces.Box(low=obs_sp.low[self.env.active_obs],
|
|
||||||
high=obs_sp.high[self.env.active_obs],
|
|
||||||
dtype=obs_sp.dtype)
|
|
||||||
|
|
||||||
assert dt is not None # this should never happen as MPWrapper is a base class
|
|
||||||
self.post_traj_steps = int(post_traj_time / dt)
|
|
||||||
|
|
||||||
self.mp = self.initialize_mp(num_dof, duration, dt, **mp_kwargs)
|
|
||||||
self.weights_scale = weights_scale
|
|
||||||
|
|
||||||
policy_class = get_policy_class(policy_type)
|
|
||||||
self.policy = policy_class(env)
|
|
||||||
|
|
||||||
# rendering
|
|
||||||
self.render_mode = render_mode
|
|
||||||
self.render_kwargs = {}
|
|
||||||
|
|
||||||
# TODO: @Max I think this should not be in this class, this functionality should be part of your sampler.
|
|
||||||
def __call__(self, params, contexts=None):
|
|
||||||
"""
|
|
||||||
Can be used to provide a batch of parameter sets
|
|
||||||
"""
|
|
||||||
params = np.atleast_2d(params)
|
|
||||||
obs = []
|
|
||||||
rewards = []
|
|
||||||
dones = []
|
|
||||||
infos = []
|
|
||||||
# for p, c in zip(params, contexts):
|
|
||||||
for p in params:
|
|
||||||
# self.configure(c)
|
|
||||||
ob, reward, done, info = self.step(p)
|
|
||||||
obs.append(ob)
|
|
||||||
rewards.append(reward)
|
|
||||||
dones.append(done)
|
|
||||||
infos.append(info)
|
|
||||||
|
|
||||||
return obs, np.array(rewards), dones, infos
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
return self.env.reset()[self.env.active_obs]
|
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
|
||||||
""" This function generates a trajectory based on a DMP and then does the usual loop over reset and step"""
|
|
||||||
trajectory, velocity = self.mp_rollout(action)
|
|
||||||
|
|
||||||
if self.post_traj_steps > 0:
|
|
||||||
trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
|
|
||||||
velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.n_dof))])
|
|
||||||
|
|
||||||
self._trajectory = trajectory
|
|
||||||
# self._velocity = velocity
|
|
||||||
|
|
||||||
rewards = 0
|
|
||||||
infos = defaultdict(list)
|
|
||||||
# create random obs as the reset function is called externally
|
|
||||||
obs = self.env.observation_space.sample()
|
|
||||||
|
|
||||||
for t, pos_vel in enumerate(zip(trajectory, velocity)):
|
|
||||||
ac = self.policy.get_action(pos_vel[0], pos_vel[1])
|
|
||||||
obs, rew, done, info = self.env.step(ac)
|
|
||||||
rewards += rew
|
|
||||||
# TODO return all dicts?
|
|
||||||
[infos[k].append(v) for k, v in info.items()]
|
|
||||||
if self.render_mode:
|
|
||||||
self.env.render(mode=self.render_mode, **self.render_kwargs)
|
|
||||||
if done:
|
|
||||||
break
|
|
||||||
|
|
||||||
done = True
|
|
||||||
return obs[self.env.active_obs], rewards, done, infos
|
|
||||||
|
|
||||||
def render(self, mode='human', **kwargs):
|
|
||||||
"""Only set render options here, such that they can be used during the rollout.
|
|
||||||
This only needs to be called once"""
|
|
||||||
self.render_mode = mode
|
|
||||||
self.render_kwargs = kwargs
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def mp_rollout(self, action):
|
|
||||||
"""
|
|
||||||
Generate trajectory and velocity based on the MP
|
|
||||||
Returns:
|
|
||||||
trajectory/positions, velocity
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def initialize_mp(self, num_dof: int, duration: float, dt: float, **kwargs):
|
|
||||||
"""
|
|
||||||
Create respective instance of MP
|
|
||||||
Returns:
|
|
||||||
MP instance
|
|
||||||
"""
|
|
||||||
|
|
||||||
raise NotImplementedError
|
|
@ -1,48 +0,0 @@
|
|||||||
from gym import Env
|
|
||||||
|
|
||||||
from alr_envs.mujoco.alr_mujoco_env import AlrMujocoEnv
|
|
||||||
|
|
||||||
|
|
||||||
class BaseController:
|
|
||||||
def __init__(self, env: Env):
|
|
||||||
self.env = env
|
|
||||||
|
|
||||||
def get_action(self, des_pos, des_vel):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class PosController(BaseController):
|
|
||||||
def get_action(self, des_pos, des_vel):
|
|
||||||
return des_pos
|
|
||||||
|
|
||||||
|
|
||||||
class VelController(BaseController):
|
|
||||||
def get_action(self, des_pos, des_vel):
|
|
||||||
return des_vel
|
|
||||||
|
|
||||||
|
|
||||||
class PDController(BaseController):
|
|
||||||
def __init__(self, env: AlrMujocoEnv):
|
|
||||||
self.p_gains = env.p_gains
|
|
||||||
self.d_gains = env.d_gains
|
|
||||||
super(PDController, self).__init__(env)
|
|
||||||
|
|
||||||
def get_action(self, des_pos, des_vel):
|
|
||||||
# TODO: make standardized ALRenv such that all of them have current_pos/vel attributes
|
|
||||||
cur_pos = self.env.current_pos
|
|
||||||
cur_vel = self.env.current_vel
|
|
||||||
if len(des_pos) != len(cur_pos):
|
|
||||||
des_pos = self.env.extend_des_pos(des_pos)
|
|
||||||
if len(des_vel) != len(cur_vel):
|
|
||||||
des_vel = self.env.extend_des_vel(des_vel)
|
|
||||||
trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
|
|
||||||
return trq
|
|
||||||
|
|
||||||
|
|
||||||
def get_policy_class(policy_type):
|
|
||||||
if policy_type == "motor":
|
|
||||||
return PDController
|
|
||||||
elif policy_type == "velocity":
|
|
||||||
return VelController
|
|
||||||
elif policy_type == "position":
|
|
||||||
return PosController
|
|
Loading…
Reference in New Issue
Block a user