shorter TT simulation time

This commit is contained in:
Onur 2022-04-07 10:20:10 +02:00
parent 855f0f1c7b
commit 8ef00f7343
2 changed files with 7 additions and 4 deletions

View File

@ -417,7 +417,7 @@ for _v, cd in enumerate(ctxt_dim):
"num_dof": 7,
"num_basis": 2,
"duration": 1.25,
"post_traj_time": 4.5,
"post_traj_time": 1.5,
"policy_type": "motor",
"weights_scale": 1.0,
"zero_start": True,

View File

@ -10,9 +10,7 @@ from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward
#TODO: Check for simulation stability. Make sure the code runs even for sim crash
# MAX_EPISODE_STEPS = 1750
# MAX_EPISODE_STEPS = 1375
MAX_EPISODE_STEPS = 2875
MAX_EPISODE_STEPS = 1375 # (1.25 + 1.5)/0.002
BALL_NAME_CONTACT = "target_ball_contact"
BALL_NAME = "target_ball"
TABLE_NAME = 'table_tennis_table'
@ -58,6 +56,7 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
self.hit_ball = False
self.ball_contact_after_hit = False
self._ids_set = False
self.n_step = 0
super(TTEnvGym, self).__init__(model_path=model_path, frame_skip=1)
self.ball_id = self.sim.model._body_name2id[BALL_NAME] # find the proper -> not protected func.
self.ball_contact_id = self.sim.model._geom_name2id[BALL_NAME_CONTACT]
@ -67,6 +66,7 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
self.paddle_contact_id_2 = self.sim.model._geom_name2id[PADDLE_CONTACT_2_NAME] # check if we need both or only this
self.racket_id = self.sim.model._geom_name2id[RACKET_NAME]
def _set_ids(self):
self.ball_id = self.sim.model._body_name2id[BALL_NAME] # find the proper -> not protected func.
self.table_contact_id = self.sim.model._geom_name2id[TABLE_NAME]
@ -119,6 +119,7 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
self.sim.forward()
self.reward_func.reset(self.goal) # reset the reward function
self.n_step = 0
return self._get_obs()
def _contact_checker(self, id_1, id_2):
@ -166,6 +167,8 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
info = {"hit_ball": self.hit_ball,
"q_pos": np.copy(self.sim.data.qpos[:7]),
"ball_pos": np.copy(self.sim.data.qpos[7:])}
self.n_step += 1
print(self.n_step)
return ob, reward, done, info # might add some information here ....
def set_context(self, context):