updates
This commit is contained in:
parent
77d0cbd00a
commit
0916daf3b5
@ -54,7 +54,7 @@ class BallInACupReward(alr_reward_fct.AlrReward):
|
|||||||
action_cost = np.sum(np.square(action))
|
action_cost = np.sum(np.square(action))
|
||||||
|
|
||||||
if self.check_collision(sim):
|
if self.check_collision(sim):
|
||||||
reward = - 1e-5 * action_cost - 1000
|
reward = - 1e-4 * action_cost - 1000
|
||||||
return reward, False, True
|
return reward, False, True
|
||||||
|
|
||||||
if step == self.sim_time - 1:
|
if step == self.sim_time - 1:
|
||||||
@ -62,10 +62,10 @@ class BallInACupReward(alr_reward_fct.AlrReward):
|
|||||||
dist_final = self.dists_final[-1]
|
dist_final = self.dists_final[-1]
|
||||||
|
|
||||||
cost = 0.5 * min_dist + 0.5 * dist_final
|
cost = 0.5 * min_dist + 0.5 * dist_final
|
||||||
reward = np.exp(-2 * cost) - 1e-5 * action_cost
|
reward = np.exp(-2 * cost) - 1e-4 * action_cost
|
||||||
success = dist_final < 0.05 and ball_in_cup
|
success = dist_final < 0.05 and ball_in_cup
|
||||||
else:
|
else:
|
||||||
reward = - 1e-5 * action_cost
|
reward = - 1e-4 * action_cost
|
||||||
success = False
|
success = False
|
||||||
|
|
||||||
return reward, success, False
|
return reward, success, False
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapper
|
from alr_envs.utils.detpmp_env_wrapper import DetPMPEnvWrapper
|
||||||
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
|
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
|
||||||
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_simple import ALRBallInACupEnv as ALRBallInACupEnvSimple
|
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_simple import ALRBallInACupEnv as ALRBallInACupEnvSimple
|
||||||
|
|
||||||
@ -20,18 +20,18 @@ def make_simple_env(rank, seed=0):
|
|||||||
def _init():
|
def _init():
|
||||||
env = ALRBallInACupEnvSimple()
|
env = ALRBallInACupEnvSimple()
|
||||||
|
|
||||||
env = DmpEnvWrapper(env,
|
env = DetPMPEnvWrapper(env,
|
||||||
policy_type="motor",
|
num_dof=3,
|
||||||
start_pos=env.start_pos[1::2],
|
num_basis=5,
|
||||||
final_pos=env.start_pos[1::2],
|
width=0.005,
|
||||||
num_dof=3,
|
policy_type="motor",
|
||||||
num_basis=8,
|
start_pos=env.start_pos[1::2],
|
||||||
duration=3.5,
|
duration=3.5,
|
||||||
alpha_phase=3,
|
post_traj_time=4.5,
|
||||||
post_traj_time=4.5,
|
dt=env.dt,
|
||||||
dt=env.dt,
|
weights_scale=0.1,
|
||||||
learn_goal=False,
|
zero_centered=True
|
||||||
weights_scale=50)
|
)
|
||||||
|
|
||||||
env.seed(seed + rank)
|
env.seed(seed + rank)
|
||||||
return env
|
return env
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
|
from alr_envs.utils.policies import get_policy_class
|
||||||
from mp_lib import det_promp
|
from mp_lib import det_promp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gym
|
||||||
|
|
||||||
|
|
||||||
class DetPMPEnvWrapperBase(gym.Wrapper):
|
class DetPMPEnvWrapper(gym.Wrapper):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
env,
|
env,
|
||||||
num_dof,
|
num_dof,
|
||||||
@ -13,13 +14,15 @@ class DetPMPEnvWrapperBase(gym.Wrapper):
|
|||||||
duration=1,
|
duration=1,
|
||||||
dt=0.01,
|
dt=0.01,
|
||||||
post_traj_time=0.,
|
post_traj_time=0.,
|
||||||
policy=None,
|
policy_type=None,
|
||||||
weights_scale=1):
|
weights_scale=1,
|
||||||
super(DetPMPEnvWrapperBase, self).__init__(env)
|
zero_centered=False):
|
||||||
|
super(DetPMPEnvWrapper, self).__init__(env)
|
||||||
self.num_dof = num_dof
|
self.num_dof = num_dof
|
||||||
self.num_basis = num_basis
|
self.num_basis = num_basis
|
||||||
self.dim = num_dof * num_basis
|
self.dim = num_dof * num_basis
|
||||||
self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, width=width, off=0.01)
|
self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=0.01,
|
||||||
|
zero_centered=zero_centered)
|
||||||
weights = np.zeros(shape=(num_basis, num_dof))
|
weights = np.zeros(shape=(num_basis, num_dof))
|
||||||
self.pmp.set_weights(duration, weights)
|
self.pmp.set_weights(duration, weights)
|
||||||
self.weights_scale = weights_scale
|
self.weights_scale = weights_scale
|
||||||
@ -29,52 +32,45 @@ class DetPMPEnvWrapperBase(gym.Wrapper):
|
|||||||
self.post_traj_steps = int(post_traj_time / dt)
|
self.post_traj_steps = int(post_traj_time / dt)
|
||||||
|
|
||||||
self.start_pos = start_pos
|
self.start_pos = start_pos
|
||||||
|
self.zero_centered = zero_centered
|
||||||
|
|
||||||
self.policy = policy
|
policy_class = get_policy_class(policy_type)
|
||||||
|
self.policy = policy_class(env)
|
||||||
|
|
||||||
def __call__(self, params):
|
def __call__(self, params, contexts=None):
|
||||||
params = np.atleast_2d(params)
|
params = np.atleast_2d(params)
|
||||||
observations = []
|
|
||||||
rewards = []
|
rewards = []
|
||||||
dones = []
|
|
||||||
infos = []
|
infos = []
|
||||||
for p in params:
|
for p, c in zip(params, contexts):
|
||||||
observation, reward, done, info = self.rollout(p)
|
reward, info = self.rollout(p, c)
|
||||||
observations.append(observation)
|
|
||||||
rewards.append(reward)
|
rewards.append(reward)
|
||||||
dones.append(done)
|
|
||||||
infos.append(info)
|
infos.append(info)
|
||||||
|
|
||||||
return np.array(rewards), infos
|
return np.array(rewards), infos
|
||||||
|
|
||||||
def rollout(self, params, render=False):
|
def rollout(self, params, context=None, render=False):
|
||||||
""" This function generates a trajectory based on a DMP and then does the usual loop over reset and step"""
|
""" This function generates a trajectory based on a DMP and then does the usual loop over reset and step"""
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class DetPMPEnvWrapperPD(DetPMPEnvWrapperBase):
|
|
||||||
"""
|
|
||||||
Wrapper for gym environments which creates a trajectory in joint velocity space
|
|
||||||
"""
|
|
||||||
def rollout(self, params, render=False):
|
|
||||||
params = np.reshape(params, newshape=(self.num_basis, self.num_dof)) * self.weights_scale
|
params = np.reshape(params, newshape=(self.num_basis, self.num_dof)) * self.weights_scale
|
||||||
self.pmp.set_weights(self.duration, params)
|
self.pmp.set_weights(self.duration, params)
|
||||||
t, des_pos, des_vel, des_acc = self.pmp.compute_trajectory(1/self.dt, 1.)
|
t, des_pos, des_vel, des_acc = self.pmp.compute_trajectory(1 / self.dt, 1.)
|
||||||
des_pos += self.start_pos[None, :]
|
if self.zero_centered:
|
||||||
|
des_pos += self.start_pos[None, :]
|
||||||
|
|
||||||
if self.post_traj_steps > 0:
|
if self.post_traj_steps > 0:
|
||||||
des_pos = np.vstack([des_pos, np.tile(des_pos[-1, :], [self.post_traj_steps, 1])])
|
des_pos = np.vstack([des_pos, np.tile(des_pos[-1, :], [self.post_traj_steps, 1])])
|
||||||
des_vel = np.vstack([des_vel, np.zeros(shape=(self.post_traj_steps, self.num_dof))])
|
des_vel = np.vstack([des_vel, np.zeros(shape=(self.post_traj_steps, self.num_dof))])
|
||||||
|
|
||||||
self._trajectory = des_pos
|
self._trajectory = des_pos
|
||||||
|
self._velocity = des_vel
|
||||||
|
|
||||||
rews = []
|
rews = []
|
||||||
infos = []
|
infos = []
|
||||||
|
|
||||||
|
self.env.configure(context)
|
||||||
self.env.reset()
|
self.env.reset()
|
||||||
|
|
||||||
for t, pos_vel in enumerate(zip(des_pos, des_vel)):
|
for t, pos_vel in enumerate(zip(des_pos, des_vel)):
|
||||||
ac = self.policy.get_action(self.env, pos_vel[0], pos_vel[1])
|
ac = self.policy.get_action(pos_vel[0], pos_vel[1])
|
||||||
obs, rew, done, info = self.env.step(ac)
|
obs, rew, done, info = self.env.step(ac)
|
||||||
rews.append(rew)
|
rews.append(rew)
|
||||||
infos.append(info)
|
infos.append(info)
|
||||||
@ -85,4 +81,5 @@ class DetPMPEnvWrapperPD(DetPMPEnvWrapperBase):
|
|||||||
|
|
||||||
reward = np.sum(rews)
|
reward = np.sum(rews)
|
||||||
|
|
||||||
return obs, reward, done, info
|
return reward, info
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from alr_envs.utils.policies import get_policy_class
|
||||||
from mp_lib.phase import ExpDecayPhaseGenerator
|
from mp_lib.phase import ExpDecayPhaseGenerator
|
||||||
from mp_lib.basis import DMPBasisGenerator
|
from mp_lib.basis import DMPBasisGenerator
|
||||||
from mp_lib import dmps
|
from mp_lib import dmps
|
||||||
@ -5,18 +6,6 @@ import numpy as np
|
|||||||
import gym
|
import gym
|
||||||
|
|
||||||
|
|
||||||
def get_policy_class(policy_type):
|
|
||||||
if policy_type == "motor":
|
|
||||||
from alr_envs.utils.policies import PDController
|
|
||||||
return PDController
|
|
||||||
elif policy_type == "velocity":
|
|
||||||
from alr_envs.utils.policies import VelController
|
|
||||||
return VelController
|
|
||||||
elif policy_type == "position":
|
|
||||||
from alr_envs.utils.policies import PosController
|
|
||||||
return PosController
|
|
||||||
|
|
||||||
|
|
||||||
class DmpEnvWrapper(gym.Wrapper):
|
class DmpEnvWrapper(gym.Wrapper):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
env,
|
env,
|
||||||
|
@ -35,3 +35,12 @@ class PDController(BaseController):
|
|||||||
des_vel = self.env.extend_des_vel(des_vel)
|
des_vel = self.env.extend_des_vel(des_vel)
|
||||||
trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
|
trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
|
||||||
return trq
|
return trq
|
||||||
|
|
||||||
|
|
||||||
|
def get_policy_class(policy_type):
|
||||||
|
if policy_type == "motor":
|
||||||
|
return PDController
|
||||||
|
elif policy_type == "velocity":
|
||||||
|
return VelController
|
||||||
|
elif policy_type == "position":
|
||||||
|
return PosController
|
||||||
|
Loading…
Reference in New Issue
Block a user