updates

commit 0916daf3b5
parent 77d0cbd00a
@@ -54,7 +54,7 @@ class BallInACupReward(alr_reward_fct.AlrReward):
         action_cost = np.sum(np.square(action))
 
         if self.check_collision(sim):
-            reward = - 1e-5 * action_cost - 1000
+            reward = - 1e-4 * action_cost - 1000
             return reward, False, True
 
         if step == self.sim_time - 1:
@@ -62,10 +62,10 @@ class BallInACupReward(alr_reward_fct.AlrReward):
             dist_final = self.dists_final[-1]
 
             cost = 0.5 * min_dist + 0.5 * dist_final
-            reward = np.exp(-2 * cost) - 1e-5 * action_cost
+            reward = np.exp(-2 * cost) - 1e-4 * action_cost
             success = dist_final < 0.05 and ball_in_cup
         else:
-            reward = - 1e-5 * action_cost
+            reward = - 1e-4 * action_cost
             success = False
 
         return reward, success, False
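For orientation, here is a minimal standalone sketch of the reward logic after the change above. The function name and flat argument list are hypothetical; the real code is a method of BallInACupReward that reads these quantities from the MuJoCo simulation.

import numpy as np

def ball_in_cup_reward_sketch(action, min_dist, dist_final, collided, last_step, ball_in_cup):
    # Quadratic action penalty, now weighted by 1e-4 instead of 1e-5.
    action_cost = np.sum(np.square(action))

    if collided:
        # Large fixed penalty on collision; the rollout is flagged as failed and done.
        return -1e-4 * action_cost - 1000, False, True

    if last_step:
        # Distance cost mapped to a bounded reward via exp(-2 * cost).
        cost = 0.5 * min_dist + 0.5 * dist_final
        reward = np.exp(-2 * cost) - 1e-4 * action_cost
        success = dist_final < 0.05 and ball_in_cup
    else:
        reward = -1e-4 * action_cost
        success = False

    return reward, success, False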
@@ -1,4 +1,4 @@
-from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapper
+from alr_envs.utils.detpmp_env_wrapper import DetPMPEnvWrapper
 from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
 from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_simple import ALRBallInACupEnv as ALRBallInACupEnvSimple
@@ -20,18 +20,18 @@ def make_simple_env(rank, seed=0):
     def _init():
         env = ALRBallInACupEnvSimple()
 
-        env = DmpEnvWrapper(env,
-                            num_dof=3,
-                            num_basis=5,
-                            policy_type="motor",
-                            start_pos=env.start_pos[1::2],
-                            final_pos=env.start_pos[1::2],
-                            duration=3.5,
-                            alpha_phase=3,
-                            post_traj_time=4.5,
-                            dt=env.dt,
-                            learn_goal=False,
-                            weights_scale=50)
+        env = DetPMPEnvWrapper(env,
+                               num_dof=3,
+                               num_basis=8,
+                               width=0.005,
+                               policy_type="motor",
+                               start_pos=env.start_pos[1::2],
+                               duration=3.5,
+                               post_traj_time=4.5,
+                               dt=env.dt,
+                               weights_scale=0.1,
+                               zero_centered=True
+                               )
 
         env.seed(seed + rank)
         return env
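A possible usage sketch for the updated factory (hypothetical, not part of the commit): the weight vector length follows from num_dof * num_basis = 3 * 8 = 24, and the contexts argument matches the new DetPMPEnvWrapper.__call__ signature shown further below; whether the simple environment accepts a None context in configure() is an assumption here.

import numpy as np

env = make_simple_env(rank=0, seed=0)()   # build one wrapped environment

# One flat weight vector per rollout; zeros give the default (zero-centered) trajectory.
params = np.zeros((1, 3 * 8))

# The wrapper is called with a batch of parameters plus one context per rollout.
rewards, infos = env(params, contexts=[None])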
@@ -1,9 +1,10 @@
+from alr_envs.utils.policies import get_policy_class
 from mp_lib import det_promp
 import numpy as np
 import gym
 
 
-class DetPMPEnvWrapperBase(gym.Wrapper):
+class DetPMPEnvWrapper(gym.Wrapper):
     def __init__(self,
                  env,
                  num_dof,
@@ -13,13 +14,15 @@ class DetPMPEnvWrapperBase(gym.Wrapper):
                  duration=1,
                  dt=0.01,
                  post_traj_time=0.,
-                 policy=None,
-                 weights_scale=1):
-        super(DetPMPEnvWrapperBase, self).__init__(env)
+                 policy_type=None,
+                 weights_scale=1,
+                 zero_centered=False):
+        super(DetPMPEnvWrapper, self).__init__(env)
         self.num_dof = num_dof
         self.num_basis = num_basis
         self.dim = num_dof * num_basis
-        self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, width=width, off=0.01)
+        self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=0.01,
+                                                zero_centered=zero_centered)
         weights = np.zeros(shape=(num_basis, num_dof))
         self.pmp.set_weights(duration, weights)
         self.weights_scale = weights_scale
@@ -29,37 +32,28 @@ class DetPMPEnvWrapperBase(gym.Wrapper):
         self.post_traj_steps = int(post_traj_time / dt)
 
         self.start_pos = start_pos
+        self.zero_centered = zero_centered
 
-        self.policy = policy
+        policy_class = get_policy_class(policy_type)
+        self.policy = policy_class(env)
 
-    def __call__(self, params):
+    def __call__(self, params, contexts=None):
         params = np.atleast_2d(params)
-        observations = []
         rewards = []
-        dones = []
         infos = []
-        for p in params:
-            observation, reward, done, info = self.rollout(p)
-            observations.append(observation)
+        for p, c in zip(params, contexts):
+            reward, info = self.rollout(p, c)
             rewards.append(reward)
-            dones.append(done)
             infos.append(info)
 
         return np.array(rewards), infos
 
-    def rollout(self, params, render=False):
+    def rollout(self, params, context=None, render=False):
         """ This function generates a trajectory based on a DMP and then does the usual loop over reset and step"""
-        raise NotImplementedError
-
-
-class DetPMPEnvWrapperPD(DetPMPEnvWrapperBase):
-    """
-    Wrapper for gym environments which creates a trajectory in joint velocity space
-    """
-    def rollout(self, params, render=False):
         params = np.reshape(params, newshape=(self.num_basis, self.num_dof)) * self.weights_scale
         self.pmp.set_weights(self.duration, params)
         t, des_pos, des_vel, des_acc = self.pmp.compute_trajectory(1 / self.dt, 1.)
+        if self.zero_centered:
+            des_pos += self.start_pos[None, :]
 
         if self.post_traj_steps > 0:
@@ -67,14 +61,16 @@ class DetPMPEnvWrapperPD(DetPMPEnvWrapperBase):
             des_vel = np.vstack([des_vel, np.zeros(shape=(self.post_traj_steps, self.num_dof))])
 
         self._trajectory = des_pos
         self._velocity = des_vel
 
         rews = []
         infos = []
 
+        self.env.configure(context)
         self.env.reset()
 
         for t, pos_vel in enumerate(zip(des_pos, des_vel)):
-            ac = self.policy.get_action(self.env, pos_vel[0], pos_vel[1])
+            ac = self.policy.get_action(pos_vel[0], pos_vel[1])
             obs, rew, done, info = self.env.step(ac)
             rews.append(rew)
             infos.append(info)
@@ -85,4 +81,5 @@ class DetPMPEnvWrapperPD(DetPMPEnvWrapperBase):
 
         reward = np.sum(rews)
 
-        return obs, reward, done, info
+        return reward, info
 
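Putting the pieces together, the merged DetPMPEnvWrapper now performs roughly the following per rollout. This is a paraphrased outline of the hunks above, not the literal implementation; it assumes numpy is imported as np inside the module.

    def rollout_outline(self, params, context=None, render=False):
        # Scale the flat parameter vector and reshape it to (num_basis, num_dof) ProMP weights.
        weights = np.reshape(params, (self.num_basis, self.num_dof)) * self.weights_scale
        self.pmp.set_weights(self.duration, weights)

        # Sample the deterministic ProMP at the control rate; zero-centered weights give a
        # trajectory around zero, which is then shifted onto the start posture.
        _, des_pos, des_vel, _ = self.pmp.compute_trajectory(1 / self.dt, 1.)
        if self.zero_centered:
            des_pos += self.start_pos[None, :]

        # Hand the (optional) context to the environment, then track the trajectory with the
        # controller resolved from policy_type (e.g. a PD controller for "motor").
        self.env.configure(context)
        self.env.reset()
        rews, info = [], {}
        for pos, vel in zip(des_pos, des_vel):
            ac = self.policy.get_action(pos, vel)
            obs, rew, done, info = self.env.step(ac)
            rews.append(rew)

        # Only the summed reward and the last info dict are returned to the caller.
        return np.sum(rews), info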
@@ -1,3 +1,4 @@
+from alr_envs.utils.policies import get_policy_class
 from mp_lib.phase import ExpDecayPhaseGenerator
 from mp_lib.basis import DMPBasisGenerator
 from mp_lib import dmps
@@ -5,18 +6,6 @@ import numpy as np
 import gym
 
 
-def get_policy_class(policy_type):
-    if policy_type == "motor":
-        from alr_envs.utils.policies import PDController
-        return PDController
-    elif policy_type == "velocity":
-        from alr_envs.utils.policies import VelController
-        return VelController
-    elif policy_type == "position":
-        from alr_envs.utils.policies import PosController
-        return PosController
-
-
 class DmpEnvWrapper(gym.Wrapper):
     def __init__(self,
                  env,
@@ -35,3 +35,12 @@ class PDController(BaseController):
         des_vel = self.env.extend_des_vel(des_vel)
         trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
         return trq
+
+
+def get_policy_class(policy_type):
+    if policy_type == "motor":
+        return PDController
+    elif policy_type == "velocity":
+        return VelController
+    elif policy_type == "position":
+        return PosController
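For reference, a small usage sketch of the relocated helper (hypothetical; the wrappers call it internally with their policy_type argument):

from alr_envs.utils.policies import get_policy_class

# Resolve the controller class from the policy_type string used by the wrappers.
policy_class = get_policy_class("motor")      # PDController (torque from PD gains)
# "velocity" -> VelController, "position" -> PosController; any other string falls
# through without an explicit else branch and returns None.

# The wrapper instantiates it with the environment and queries actions per step:
#   policy = policy_class(env)
#   ac = policy.get_action(des_pos, des_vel)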