From 8ef00f73432d64ca600263b5ad2840ff454433d6 Mon Sep 17 00:00:00 2001 From: Onur Date: Thu, 7 Apr 2022 10:20:10 +0200 Subject: [PATCH] shorter TT simulation time --- alr_envs/alr/__init__.py | 2 +- alr_envs/alr/mujoco/table_tennis/tt_gym.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/alr_envs/alr/__init__.py b/alr_envs/alr/__init__.py index 917e102..931e3bb 100644 --- a/alr_envs/alr/__init__.py +++ b/alr_envs/alr/__init__.py @@ -417,7 +417,7 @@ for _v, cd in enumerate(ctxt_dim): "num_dof": 7, "num_basis": 2, "duration": 1.25, - "post_traj_time": 4.5, + "post_traj_time": 1.5, "policy_type": "motor", "weights_scale": 1.0, "zero_start": True, diff --git a/alr_envs/alr/mujoco/table_tennis/tt_gym.py b/alr_envs/alr/mujoco/table_tennis/tt_gym.py index c93cd26..7079a1b 100644 --- a/alr_envs/alr/mujoco/table_tennis/tt_gym.py +++ b/alr_envs/alr/mujoco/table_tennis/tt_gym.py @@ -10,9 +10,7 @@ from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward #TODO: Check for simulation stability. Make sure the code runs even for sim crash -# MAX_EPISODE_STEPS = 1750 -# MAX_EPISODE_STEPS = 1375 -MAX_EPISODE_STEPS = 2875 +MAX_EPISODE_STEPS = 1375 # (1.25 + 1.5)/0.002 BALL_NAME_CONTACT = "target_ball_contact" BALL_NAME = "target_ball" TABLE_NAME = 'table_tennis_table' @@ -58,6 +56,7 @@ class TTEnvGym(MujocoEnv, utils.EzPickle): self.hit_ball = False self.ball_contact_after_hit = False self._ids_set = False + self.n_step = 0 super(TTEnvGym, self).__init__(model_path=model_path, frame_skip=1) self.ball_id = self.sim.model._body_name2id[BALL_NAME] # find the proper -> not protected func. self.ball_contact_id = self.sim.model._geom_name2id[BALL_NAME_CONTACT] @@ -67,6 +66,7 @@ class TTEnvGym(MujocoEnv, utils.EzPickle): self.paddle_contact_id_2 = self.sim.model._geom_name2id[PADDLE_CONTACT_2_NAME] # check if we need both or only this self.racket_id = self.sim.model._geom_name2id[RACKET_NAME] + def _set_ids(self): self.ball_id = self.sim.model._body_name2id[BALL_NAME] # find the proper -> not protected func. self.table_contact_id = self.sim.model._geom_name2id[TABLE_NAME] @@ -119,6 +119,7 @@ class TTEnvGym(MujocoEnv, utils.EzPickle): self.sim.forward() self.reward_func.reset(self.goal) # reset the reward function + self.n_step = 0 return self._get_obs() def _contact_checker(self, id_1, id_2): @@ -166,6 +167,8 @@ class TTEnvGym(MujocoEnv, utils.EzPickle): info = {"hit_ball": self.hit_ball, "q_pos": np.copy(self.sim.data.qpos[:7]), "ball_pos": np.copy(self.sim.data.qpos[7:])} + self.n_step += 1 + print(self.n_step) return ob, reward, done, info # might add some information here .... def set_context(self, context):