Maximilian Huettenrauch 2021-02-15 16:31:34 +01:00
parent 77d0cbd00a
commit 0916daf3b5
5 changed files with 49 additions and 54 deletions

View File

@@ -54,7 +54,7 @@ class BallInACupReward(alr_reward_fct.AlrReward):
action_cost = np.sum(np.square(action))
if self.check_collision(sim):
reward = - 1e-5 * action_cost - 1000
reward = - 1e-4 * action_cost - 1000
return reward, False, True
if step == self.sim_time - 1:
@@ -62,10 +62,10 @@ class BallInACupReward(alr_reward_fct.AlrReward):
dist_final = self.dists_final[-1]
cost = 0.5 * min_dist + 0.5 * dist_final
reward = np.exp(-2 * cost) - 1e-5 * action_cost
reward = np.exp(-2 * cost) - 1e-4 * action_cost
success = dist_final < 0.05 and ball_in_cup
else:
reward = - 1e-5 * action_cost
reward = - 1e-4 * action_cost
success = False
return reward, success, False
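The only functional change in this file is the action-cost coefficient, raised from 1e-5 to 1e-4 in all three branches, so control effort is penalized ten times more strongly. A minimal sketch of the effect in isolation, with an illustrative action vector (actual torques depend on the controller):

import numpy as np

action = np.array([1.5, -0.8, 2.0])       # illustrative joint torques, not taken from the env
action_cost = np.sum(np.square(action))   # same cost term as in BallInACupReward

old_penalty = 1e-5 * action_cost          # previous scaling
new_penalty = 1e-4 * action_cost          # new scaling, exactly 10x larger
print(old_penalty, new_penalty)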

View File

@@ -1,4 +1,4 @@
from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapper
from alr_envs.utils.detpmp_env_wrapper import DetPMPEnvWrapper
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_simple import ALRBallInACupEnv as ALRBallInACupEnvSimple
@@ -20,18 +20,18 @@ def make_simple_env(rank, seed=0):
def _init():
env = ALRBallInACupEnvSimple()
env = DmpEnvWrapper(env,
env = DetPMPEnvWrapper(env,
num_dof=3,
num_basis=5,
width=0.005,
policy_type="motor",
start_pos=env.start_pos[1::2],
final_pos=env.start_pos[1::2],
num_dof=3,
num_basis=8,
duration=3.5,
alpha_phase=3,
post_traj_time=4.5,
dt=env.dt,
learn_goal=False,
weights_scale=50)
weights_scale=0.1,
zero_centered=True
)
env.seed(seed + rank)
return env
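How the factory is consumed is not shown in this hunk; a hedged usage sketch, assuming make_simple_env returns the _init closure (the usual vectorized-env pattern) and that the import path below is where it lives:

import numpy as np
from alr_envs.mujoco.ball_in_a_cup.utils import make_simple_env   # import path assumed

env_fn = make_simple_env(rank=0, seed=0)        # assumed to hand back the _init closure
env = env_fn()
params = np.zeros((1, env.dim))                 # env.dim = num_dof * num_basis, set by DetPMPEnvWrapper
rewards, infos = env(params, contexts=[None])   # contexts must be iterable; __call__ zips it with params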

View File

@@ -1,9 +1,10 @@
from alr_envs.utils.policies import get_policy_class
from mp_lib import det_promp
import numpy as np
import gym
class DetPMPEnvWrapperBase(gym.Wrapper):
class DetPMPEnvWrapper(gym.Wrapper):
def __init__(self,
env,
num_dof,
@@ -13,13 +14,15 @@ class DetPMPEnvWrapperBase(gym.Wrapper):
duration=1,
dt=0.01,
post_traj_time=0.,
policy=None,
weights_scale=1):
super(DetPMPEnvWrapperBase, self).__init__(env)
policy_type=None,
weights_scale=1,
zero_centered=False):
super(DetPMPEnvWrapper, self).__init__(env)
self.num_dof = num_dof
self.num_basis = num_basis
self.dim = num_dof * num_basis
self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, width=width, off=0.01)
self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=0.01,
zero_centered=zero_centered)
weights = np.zeros(shape=(num_basis, num_dof))
self.pmp.set_weights(duration, weights)
self.weights_scale = weights_scale
@@ -29,37 +32,28 @@ class DetPMPEnvWrapperBase(gym.Wrapper):
self.post_traj_steps = int(post_traj_time / dt)
self.start_pos = start_pos
self.zero_centered = zero_centered
self.policy = policy
policy_class = get_policy_class(policy_type)
self.policy = policy_class(env)
def __call__(self, params):
def __call__(self, params, contexts=None):
params = np.atleast_2d(params)
observations = []
rewards = []
dones = []
infos = []
for p in params:
observation, reward, done, info = self.rollout(p)
observations.append(observation)
for p, c in zip(params, contexts):
reward, info = self.rollout(p, c)
rewards.append(reward)
dones.append(done)
infos.append(info)
return np.array(rewards), infos
def rollout(self, params, render=False):
def rollout(self, params, context=None, render=False):
""" This function generates a trajectory based on a DMP and then does the usual loop over reset and step"""
raise NotImplementedError
class DetPMPEnvWrapperPD(DetPMPEnvWrapperBase):
"""
Wrapper for gym environments which creates a trajectory in joint velocity space
"""
def rollout(self, params, render=False):
params = np.reshape(params, newshape=(self.num_basis, self.num_dof)) * self.weights_scale
self.pmp.set_weights(self.duration, params)
t, des_pos, des_vel, des_acc = self.pmp.compute_trajectory(1 / self.dt, 1.)
if self.zero_centered:
des_pos += self.start_pos[None, :]
if self.post_traj_steps > 0:
@@ -67,14 +61,16 @@ class DetPMPEnvWrapperPD(DetPMPEnvWrapperBase):
des_vel = np.vstack([des_vel, np.zeros(shape=(self.post_traj_steps, self.num_dof))])
self._trajectory = des_pos
self._velocity = des_vel
rews = []
infos = []
self.env.configure(context)
self.env.reset()
for t, pos_vel in enumerate(zip(des_pos, des_vel)):
ac = self.policy.get_action(self.env, pos_vel[0], pos_vel[1])
ac = self.policy.get_action(pos_vel[0], pos_vel[1])
obs, rew, done, info = self.env.step(ac)
rews.append(rew)
infos.append(info)
@@ -85,4 +81,5 @@ class DetPMPEnvWrapperPD(DetPMPEnvWrapperBase):
reward = np.sum(rews)
return obs, reward, done, info
return reward, info
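Inside rollout the flat parameter vector is reshaped to (num_basis, num_dof), scaled by weights_scale, loaded into the deterministic ProMP, sampled at 1/dt, and then tracked step by step by the policy. A minimal sketch of the trajectory-generation half, using only the mp_lib calls that appear in this diff (the num_dof/num_basis/dt/width values are illustrative):

import numpy as np
from mp_lib import det_promp

num_dof, num_basis, dt, duration = 3, 5, 0.01, 3.5     # illustrative values
pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=0.005,
                                   off=0.01, zero_centered=True)

params = np.random.randn(num_basis * num_dof)              # flat vector, as handed to rollout
weights = np.reshape(params, (num_basis, num_dof)) * 0.1   # weights_scale as in the wrapper
pmp.set_weights(duration, weights)
t, des_pos, des_vel, des_acc = pmp.compute_trajectory(1 / dt, 1.)
# des_pos / des_vel are then tracked per step via policy.get_action(des_pos[i], des_vel[i])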

View File

@@ -1,3 +1,4 @@
from alr_envs.utils.policies import get_policy_class
from mp_lib.phase import ExpDecayPhaseGenerator
from mp_lib.basis import DMPBasisGenerator
from mp_lib import dmps
@@ -5,18 +6,6 @@ import numpy as np
import gym
def get_policy_class(policy_type):
if policy_type == "motor":
from alr_envs.utils.policies import PDController
return PDController
elif policy_type == "velocity":
from alr_envs.utils.policies import VelController
return VelController
elif policy_type == "position":
from alr_envs.utils.policies import PosController
return PosController
class DmpEnvWrapper(gym.Wrapper):
def __init__(self,
env,

View File

@@ -35,3 +35,12 @@ class PDController(BaseController):
des_vel = self.env.extend_des_vel(des_vel)
trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
return trq
def get_policy_class(policy_type):
if policy_type == "motor":
return PDController
elif policy_type == "velocity":
return VelController
elif policy_type == "position":
return PosController
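With get_policy_class moved next to the controllers, the wrapper can resolve a controller from the policy_type string alone (policy_class = get_policy_class(policy_type); self.policy = policy_class(env)). A hedged usage sketch; des_pos_t and des_vel_t are placeholder names for one setpoint of the computed trajectory, and env must expose whatever the controller reads (current joint state, gains, the extend_des_* hooks):

from alr_envs.utils.policies import get_policy_class

policy_class = get_policy_class("motor")        # "motor" -> PDController; "velocity"/"position" -> the others
policy = policy_class(env)                      # env: an already constructed ALR environment
ac = policy.get_action(des_pos_t, des_vel_t)    # matches the call made inside DetPMPEnvWrapper.rollout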