biac simple dmp env
This commit is contained in:
parent
448ebcde95
commit
4673a8c13b
@ -34,6 +34,7 @@ class BallInACupReward(alr_reward_fct.AlrReward):
|
||||
self.dists_final = []
|
||||
self.costs = []
|
||||
self.action_costs = []
|
||||
self.cup_angles = []
|
||||
|
||||
def compute_reward(self, action, sim, step, context=None):
|
||||
self.ball_id = sim.model._body_name2id["ball"]
|
||||
@ -51,6 +52,9 @@ class BallInACupReward(alr_reward_fct.AlrReward):
|
||||
self.dists.append(np.linalg.norm(goal_pos - ball_pos))
|
||||
self.dists_final.append(np.linalg.norm(goal_final_pos - ball_pos))
|
||||
self.ball_traj[step, :] = ball_pos
|
||||
cup_quat = np.copy(sim.data.body_xquat[sim.model._body_name2id["cup"]])
|
||||
self.cup_angles.append(np.arctan2(2 * (cup_quat[0] * cup_quat[1] + cup_quat[2] * cup_quat[3]),
|
||||
1 - 2 * (cup_quat[1]**2 + cup_quat[2]**2)))
|
||||
|
||||
action_cost = np.sum(np.square(action))
|
||||
self.action_costs.append(action_cost)
|
||||
@ -60,10 +64,14 @@ class BallInACupReward(alr_reward_fct.AlrReward):
|
||||
return reward, False, True
|
||||
|
||||
if step == self.sim_time - 1:
|
||||
min_dist = np.min(self.dists)
|
||||
t_min_dist = np.argmin(self.dists)
|
||||
angle_min_dist = self.cup_angles[t_min_dist]
|
||||
cost_angle = (angle_min_dist - np.pi / 2)**2
|
||||
|
||||
min_dist = self.dists[t_min_dist]
|
||||
dist_final = self.dists_final[-1]
|
||||
|
||||
cost = 0.5 * min_dist + 0.5 * dist_final
|
||||
cost = 0.5 * min_dist + 0.5 * dist_final + 0.01 * cost_angle
|
||||
reward = np.exp(-2 * cost) - 1e-3 * action_cost
|
||||
success = dist_final < 0.05 and ball_in_cup
|
||||
else:
|
||||
|
@ -1,4 +1,5 @@
|
||||
from alr_envs.utils.detpmp_env_wrapper import DetPMPEnvWrapper
|
||||
from alr_envs.utils.dmp_env_wrapper import DmpEnvWrapper
|
||||
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
|
||||
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_simple import ALRBallInACupEnv as ALRBallInACupEnvSimple
|
||||
|
||||
@ -104,3 +105,38 @@ def make_simple_env(rank, seed=0):
|
||||
return env
|
||||
|
||||
return _init
|
||||
|
||||
|
||||
def make_simple_dmp_env(rank, seed=0):
    """
    Utility function for multiprocessed env.

    Builds a factory for the simple Ball-in-a-Cup environment wrapped in a
    ``DmpEnvWrapper``. Nothing is constructed until the returned function
    is called.

    :param rank: (int) index of the subprocess; added to ``seed`` to obtain the seed passed to the env
    :param seed: (int) the initial seed for RNG
    :returns a function that generates an environment
    """

    def _init():
        base_env = ALRBallInACupEnvSimple()

        # start_pos[1::2] keeps every second entry of the start position,
        # beginning at index 1.
        # NOTE(review): presumably these are the 3 joints driven by the DMP
        # (num_dof=3) -- confirm against ALRBallInACupEnvSimple.
        # The same slice is used for start and final position, and
        # learn_goal is False.
        wrapped_env = DmpEnvWrapper(base_env,
                                    num_dof=3,
                                    num_basis=5,
                                    duration=3.5,
                                    post_traj_time=4.5,
                                    bandwidth_factor=2.5,
                                    dt=base_env.dt,
                                    learn_goal=False,
                                    alpha_phase=3,
                                    start_pos=base_env.start_pos[1::2],
                                    final_pos=base_env.start_pos[1::2],
                                    policy_type="motor",
                                    weights_scale=100,
                                    )

        # Offset the seed by the rank so each call site with a different
        # rank seeds its environment differently.
        wrapped_env.seed(seed + rank)
        return wrapped_env

    return _init
|
||||
|
@ -1,4 +1,4 @@
|
||||
from alr_envs.mujoco.ball_in_a_cup.utils import make_env, make_simple_env
|
||||
from alr_envs.mujoco.ball_in_a_cup.utils import make_env, make_simple_env, make_simple_dmp_env
|
||||
from alr_envs.utils.dmp_async_vec_env import DmpAsyncVectorEnv
|
||||
import numpy as np
|
||||
|
||||
@ -18,13 +18,13 @@ if __name__ == "__main__":
|
||||
# rewards, infos = vec_env(params)
|
||||
# print(rewards)
|
||||
#
|
||||
non_vec_env = make_simple_env(0, 0)()
|
||||
non_vec_env = make_simple_dmp_env(0, 0)()
|
||||
|
||||
# params = 0.5 * np.random.randn(dim)
|
||||
params = np.array([[11.90777345, 4.47656072, -2.49030537, 2.29386444, -3.5645336 ,
|
||||
2.97729181, 4.65224072, 3.72020235, 4.3658366 , -5.8489886 ,
|
||||
9.8045112 , 2.95405854, 4.56178261, 4.70669295, 4.55522522]])
|
||||
params = np.array([[-2.04114375, -2.62248565, 1.35999138, 4.29883804, 0.09143854,
|
||||
8.1752235 , -1.47063842, 0.60865483, -3.1697385 , 10.95458786,
|
||||
2.81887935, 3.6400505 , 1.43011501, -4.36044191, -3.66816722]])
|
||||
|
||||
out2 = non_vec_env.rollout(params, render=True)
|
||||
out2 = non_vec_env.rollout(params, render=False)
|
||||
|
||||
print(out2)
|
||||
|
Loading…
Reference in New Issue
Block a user