shorter TT simulation time
This commit is contained in:
parent
855f0f1c7b
commit
8ef00f7343
@ -417,7 +417,7 @@ for _v, cd in enumerate(ctxt_dim):
|
|||||||
"num_dof": 7,
|
"num_dof": 7,
|
||||||
"num_basis": 2,
|
"num_basis": 2,
|
||||||
"duration": 1.25,
|
"duration": 1.25,
|
||||||
"post_traj_time": 4.5,
|
"post_traj_time": 1.5,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 1.0,
|
"weights_scale": 1.0,
|
||||||
"zero_start": True,
|
"zero_start": True,
|
||||||
|
@ -10,9 +10,7 @@ from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward
|
|||||||
|
|
||||||
#TODO: Check for simulation stability. Make sure the code runs even for sim crash
|
#TODO: Check for simulation stability. Make sure the code runs even for sim crash
|
||||||
|
|
||||||
# MAX_EPISODE_STEPS = 1750
|
MAX_EPISODE_STEPS = 1375 # (1.25 + 1.5)/0.002
|
||||||
# MAX_EPISODE_STEPS = 1375
|
|
||||||
MAX_EPISODE_STEPS = 2875
|
|
||||||
BALL_NAME_CONTACT = "target_ball_contact"
|
BALL_NAME_CONTACT = "target_ball_contact"
|
||||||
BALL_NAME = "target_ball"
|
BALL_NAME = "target_ball"
|
||||||
TABLE_NAME = 'table_tennis_table'
|
TABLE_NAME = 'table_tennis_table'
|
||||||
@ -58,6 +56,7 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
|
|||||||
self.hit_ball = False
|
self.hit_ball = False
|
||||||
self.ball_contact_after_hit = False
|
self.ball_contact_after_hit = False
|
||||||
self._ids_set = False
|
self._ids_set = False
|
||||||
|
self.n_step = 0
|
||||||
super(TTEnvGym, self).__init__(model_path=model_path, frame_skip=1)
|
super(TTEnvGym, self).__init__(model_path=model_path, frame_skip=1)
|
||||||
self.ball_id = self.sim.model._body_name2id[BALL_NAME] # find the proper -> not protected func.
|
self.ball_id = self.sim.model._body_name2id[BALL_NAME] # find the proper -> not protected func.
|
||||||
self.ball_contact_id = self.sim.model._geom_name2id[BALL_NAME_CONTACT]
|
self.ball_contact_id = self.sim.model._geom_name2id[BALL_NAME_CONTACT]
|
||||||
@ -67,6 +66,7 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
|
|||||||
self.paddle_contact_id_2 = self.sim.model._geom_name2id[PADDLE_CONTACT_2_NAME] # check if we need both or only this
|
self.paddle_contact_id_2 = self.sim.model._geom_name2id[PADDLE_CONTACT_2_NAME] # check if we need both or only this
|
||||||
self.racket_id = self.sim.model._geom_name2id[RACKET_NAME]
|
self.racket_id = self.sim.model._geom_name2id[RACKET_NAME]
|
||||||
|
|
||||||
|
|
||||||
def _set_ids(self):
|
def _set_ids(self):
|
||||||
self.ball_id = self.sim.model._body_name2id[BALL_NAME] # find the proper -> not protected func.
|
self.ball_id = self.sim.model._body_name2id[BALL_NAME] # find the proper -> not protected func.
|
||||||
self.table_contact_id = self.sim.model._geom_name2id[TABLE_NAME]
|
self.table_contact_id = self.sim.model._geom_name2id[TABLE_NAME]
|
||||||
@ -119,6 +119,7 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
|
|||||||
self.sim.forward()
|
self.sim.forward()
|
||||||
|
|
||||||
self.reward_func.reset(self.goal) # reset the reward function
|
self.reward_func.reset(self.goal) # reset the reward function
|
||||||
|
self.n_step = 0
|
||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def _contact_checker(self, id_1, id_2):
|
def _contact_checker(self, id_1, id_2):
|
||||||
@ -166,6 +167,8 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
|
|||||||
info = {"hit_ball": self.hit_ball,
|
info = {"hit_ball": self.hit_ball,
|
||||||
"q_pos": np.copy(self.sim.data.qpos[:7]),
|
"q_pos": np.copy(self.sim.data.qpos[:7]),
|
||||||
"ball_pos": np.copy(self.sim.data.qpos[7:])}
|
"ball_pos": np.copy(self.sim.data.qpos[7:])}
|
||||||
|
self.n_step += 1
|
||||||
|
print(self.n_step)
|
||||||
return ob, reward, done, info # might add some information here ....
|
return ob, reward, done, info # might add some information here ....
|
||||||
|
|
||||||
def set_context(self, context):
|
def set_context(self, context):
|
||||||
|
Loading…
Reference in New Issue
Block a user