update with current settings

This commit is contained in:
Hongyi Zhou 2022-11-15 11:37:25 +01:00
parent b482c1bd89
commit f6128460ed
3 changed files with 14 additions and 8 deletions

View File

@ -259,7 +259,8 @@ for ctxt_dim in [2, 4]:
kwargs={ kwargs={
"ctxt_dim": ctxt_dim, "ctxt_dim": ctxt_dim,
'frame_skip': 4, 'frame_skip': 4,
'enable_wind': True 'enable_wind': False,
'enable_switching_goal': False,
} }
) )
@ -577,7 +578,6 @@ for _v in _versions:
kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try kwargs_dict_tt_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
kwargs_dict_tt_prodmp['basis_generator_kwargs']['pre_compute_length_factor'] = 5 kwargs_dict_tt_prodmp['basis_generator_kwargs']['pre_compute_length_factor'] = 5
kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3 kwargs_dict_tt_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
# kwargs_dict_tt_prodmp['black_box_kwargs']['duration'] = 4.
kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3 kwargs_dict_tt_prodmp['black_box_kwargs']['max_planning_times'] = 3
kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 50 == 0 kwargs_dict_tt_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 50 == 0
register( register(

View File

@ -71,4 +71,4 @@ class MPWrapper(RawInterfaceWrapper):
"is_success": [False], "is_success": [False],
'trajectory_length': 1, 'trajectory_length': 1,
"num_steps": [1] "num_steps": [1]
} }

View File

@ -43,6 +43,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self._racket_traj = [] self._racket_traj = []
self._enable_goal_switching = enable_switching_goal
MujocoEnv.__init__(self, MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
frame_skip=frame_skip, frame_skip=frame_skip,
@ -56,8 +58,6 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32) self.action_space = spaces.Box(low=-1, high=1, shape=(7,), dtype=np.float32)
self._enable_switching_goal = enable_switching_goal
# complex dynamics settings # complex dynamics settings
self.model.opt.density = 1.225 self.model.opt.density = 1.225
self.model.opt.viscosity = 2.27e-5 self.model.opt.viscosity = 2.27e-5
@ -79,6 +79,12 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
unstable_simulation = False unstable_simulation = False
if self._enable_goal_switching:
if self._steps == 45 and self.np_random.uniform(0, 1) < 0.5:
self._goal_pos[1] = -self._goal_pos[1]
self.model.body_pos[5] = np.concatenate([self._goal_pos, [0.77]])
mujoco.mj_forward(self.model, self.data)
for _ in range(self.frame_skip): for _ in range(self.frame_skip):
try: try:
self.do_simulation(action, 1) self.do_simulation(action, 1)
@ -180,9 +186,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self.data.joint("tar_x").qpos.copy(), self.data.joint("tar_x").qpos.copy(),
self.data.joint("tar_y").qpos.copy(), self.data.joint("tar_y").qpos.copy(),
self.data.joint("tar_z").qpos.copy(), self.data.joint("tar_z").qpos.copy(),
self.data.joint("tar_x").qvel.copy(), # self.data.joint("tar_x").qvel.copy(),
self.data.joint("tar_y").qvel.copy(), # self.data.joint("tar_y").qvel.copy(),
self.data.joint("tar_z").qvel.copy(), # self.data.joint("tar_z").qvel.copy(),
# self.data.body("target_ball").xvel.copy(), # self.data.body("target_ball").xvel.copy(),
self._goal_pos.copy(), self._goal_pos.copy(),
]) ])