HoleReacher PMP updates
This commit is contained in:
parent
a3d365033a
commit
bafa28edca
@ -99,7 +99,10 @@ class HoleReacher(gym.Env):
|
|||||||
if self._steps == 199:
|
if self._steps == 199:
|
||||||
reward = - np.linalg.norm(self.end_effector - self.bottom_center_of_hole) ** 2
|
reward = - np.linalg.norm(self.end_effector - self.bottom_center_of_hole) ** 2
|
||||||
else:
|
else:
|
||||||
reward = -self.collision_penalty
|
if self.collision_penalty != 0:
|
||||||
|
reward = -self.collision_penalty
|
||||||
|
else:
|
||||||
|
reward = - np.linalg.norm(self.end_effector - self.bottom_center_of_hole) ** 2
|
||||||
|
|
||||||
reward -= 5e-8 * np.sum(acc ** 2)
|
reward -= 5e-8 * np.sum(acc ** 2)
|
||||||
|
|
||||||
|
@ -134,18 +134,18 @@ def make_holereacher_env_pmp(rank, seed=0):
|
|||||||
hole_width=0.15,
|
hole_width=0.15,
|
||||||
hole_depth=1,
|
hole_depth=1,
|
||||||
hole_x=1,
|
hole_x=1,
|
||||||
collision_penalty=1000)
|
collision_penalty=100)
|
||||||
|
|
||||||
_env = DetPMPEnvWrapper(_env,
|
_env = DetPMPEnvWrapper(_env,
|
||||||
num_dof=5,
|
num_dof=5,
|
||||||
num_basis=5,
|
num_basis=5,
|
||||||
width=0.005,
|
width=0.025,
|
||||||
policy_type="velocity",
|
policy_type="velocity",
|
||||||
start_pos=_env.start_pos,
|
start_pos=_env.start_pos,
|
||||||
duration=2,
|
duration=2,
|
||||||
post_traj_time=0,
|
post_traj_time=0,
|
||||||
dt=_env.dt,
|
dt=_env.dt,
|
||||||
weights_scale=0.25,
|
weights_scale=0.2,
|
||||||
zero_start=True,
|
zero_start=True,
|
||||||
zero_goal=False
|
zero_goal=False
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from alr_envs.classic_control.utils import make_viapointreacher_env
|
from alr_envs.classic_control.utils import make_viapointreacher_env
|
||||||
from alr_envs.classic_control.utils import make_holereacher_env, make_holereacher_fix_goal_env
|
from alr_envs.classic_control.utils import make_holereacher_env, make_holereacher_fix_goal_env, make_holereacher_env_pmp
|
||||||
from alr_envs.utils.dmp_async_vec_env import DmpAsyncVectorEnv
|
from alr_envs.utils.dmp_async_vec_env import DmpAsyncVectorEnv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -8,22 +8,21 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
n_samples = 1
|
n_samples = 1
|
||||||
n_cpus = 4
|
n_cpus = 4
|
||||||
dim = 30
|
dim = 15
|
||||||
|
|
||||||
# env = DmpAsyncVectorEnv([make_viapointreacher_env(i) for i in range(n_cpus)],
|
# env = DmpAsyncVectorEnv([make_viapointreacher_env(i) for i in range(n_cpus)],
|
||||||
# n_samples=n_samples)
|
# n_samples=n_samples)
|
||||||
|
|
||||||
test_env = make_holereacher_env(0)()
|
test_env = make_holereacher_env_pmp(0)()
|
||||||
|
|
||||||
# params = np.random.randn(n_samples, dim)
|
# params = 1 * np.random.randn(dim)
|
||||||
params = np.array([ 0.57622273, 0.98294602, 1.48964131, 0.65430972,
|
params = np.array([[ -0.13106822, -0.66268577, -1.37025136, -1.34813613,
|
||||||
-0.26028221, 4.84693322, 1.77366128, 0.51080511,
|
-0.34040336, -1.41684643, 2.81882318, -1.93383471,
|
||||||
-2.38201107, -0.84990048, 1.02289828, 1.20675551,
|
-5.84213385, -3.8623558 , -1.31946267, 3.19346678,
|
||||||
0.38075566, -1.84282938, -3.48690172, 2.17434711,
|
-9.6581148 , -8.27402906, -0.42374776, -2.06852054,
|
||||||
-1.79285349, -1.7533641 , 0.62802966, 1.18928357,
|
7.21224904, -6.81061422, -9.54973119, -6.18636867,
|
||||||
0.2818753 , -3.27708291, -0.91761804, -0.38350967,
|
-6.82998929, 13.00398992, -18.28106949, -6.06678165,
|
||||||
2.25849139, 21.57786524, -14.38494647, -11.5380005 ,
|
2.79744735]])
|
||||||
-11.09529721, -0.39453533])
|
|
||||||
|
|
||||||
# params = np.hstack([50 * np.random.randn(n_samples, 25), np.tile(np.array([np.pi/2, -np.pi/4, -np.pi/4, -np.pi/4, -np.pi/4]), [n_samples, 1])])
|
# params = np.hstack([50 * np.random.randn(n_samples, 25), np.tile(np.array([np.pi/2, -np.pi/4, -np.pi/4, -np.pi/4, -np.pi/4]), [n_samples, 1])])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user