updated table tennis and beerpong for ProMP usage
This commit is contained in:
parent 083e937e17
commit a0af743585
@@ -198,14 +198,19 @@ register(
 
 ## Table Tennis
 register(id='TableTennis2DCtxt-v0',
-         entry_point='alr_envs.alr.mujoco:TT_Env_Gym',
+         entry_point='alr_envs.alr.mujoco:TTEnvGym',
          max_episode_steps=MAX_EPISODE_STEPS,
-         kwargs={'ctxt_dim':2})
+         kwargs={'ctxt_dim': 2})
+
+register(id='TableTennis2DCtxt-v1',
+         entry_point='alr_envs.alr.mujoco:TTEnvGym',
+         max_episode_steps=1750,
+         kwargs={'ctxt_dim': 2, 'fixed_goal': True})
 
 register(id='TableTennis4DCtxt-v0',
-         entry_point='alr_envs.alr.mujoco:TT_Env_Gym',
+         entry_point='alr_envs.alr.mujoco:TTEnvGym',
          max_episode_steps=MAX_EPISODE_STEPS,
-         kwargs={'ctxt_dim':4})
+         kwargs={'ctxt_dim': 4})
 
 ## BeerPong
 difficulties = ["simple", "intermediate", "hard", "hardest"]
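As a reading aid, here is a minimal sketch of rolling out the newly registered fixed-goal variant (assuming `gym` and this package are importable; the `hit_ball` info key comes from `TTEnvGym.step`, changed further below):

```python
import gym
import alr_envs  # noqa: F401  # importing the package runs the register() calls above

env = gym.make("TableTennis2DCtxt-v1")  # fixed goal, max_episode_steps=1750
obs = env.reset()
done = False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
print("hit_ball:", info["hit_ball"])
```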
@@ -369,13 +374,10 @@ register(
         "mp_kwargs": {
             "num_dof": 7,
             "num_basis": 2,
-            "n_zero_bases": 2,
-            "duration": 0.5,
-            "post_traj_time": 2.5,
-            # "width": 0.01,
-            # "off": 0.01,
+            "duration": 1,
+            "post_traj_time": 2,
             "policy_type": "motor",
-            "weights_scale": 0.08,
+            "weights_scale": 0.2,
             "zero_start": True,
             "zero_goal": False,
             "policy_kwargs": {
| @ -388,22 +390,46 @@ register( | |||||||
| ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("BeerpongProMP-v0") | ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("BeerpongProMP-v0") | ||||||
| 
 | 
 | ||||||
| ## Table Tennis | ## Table Tennis | ||||||
|  | ctxt_dim = [2, 4] | ||||||
|  | for _v, cd in enumerate(ctxt_dim): | ||||||
|  |     _env_id = f'TableTennisProMP-v{_v}' | ||||||
|  |     register( | ||||||
|  |         id=_env_id, | ||||||
|  |         entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', | ||||||
|  |         kwargs={ | ||||||
|  |             "name": "alr_envs:TableTennis{}DCtxt-v0".format(cd), | ||||||
|  |             "wrappers": [mujoco.table_tennis.MPWrapper], | ||||||
|  |             "mp_kwargs": { | ||||||
|  |                 "num_dof": 7, | ||||||
|  |                 "num_basis": 2, | ||||||
|  |                 "duration": 1.25, | ||||||
|  |                 "post_traj_time": 4.5, | ||||||
|  |                 "policy_type": "motor", | ||||||
|  |                 "weights_scale": 1.0, | ||||||
|  |                 "zero_start": True, | ||||||
|  |                 "zero_goal": False, | ||||||
|  |                 "policy_kwargs": { | ||||||
|  |                     "p_gains": 0.5*np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0]), | ||||||
|  |                     "d_gains": 0.5*np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1]) | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     ) | ||||||
|  |     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) | ||||||
|  | 
 | ||||||
| register( | register( | ||||||
|     id='TableTennisProMP-v0', |     id='TableTennisProMP-v2', | ||||||
|     entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', |     entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', | ||||||
|     kwargs={ |     kwargs={ | ||||||
|         "name": "alr_envs:TableTennis4DCtxt-v0", |         "name": "alr_envs:TableTennis2DCtxt-v1", | ||||||
|         "wrappers": [mujoco.table_tennis.MPWrapper], |         "wrappers": [mujoco.table_tennis.MPWrapper], | ||||||
|         "mp_kwargs": { |         "mp_kwargs": { | ||||||
|             "num_dof": 7, |             "num_dof": 7, | ||||||
|             "num_basis": 2, |             "num_basis": 2, | ||||||
|             "n_zero_bases": 2, |             "duration": 1., | ||||||
|             "duration": 1.25, |             "post_traj_time": 2.5, | ||||||
|             "post_traj_time": 4.5, |  | ||||||
|             # "width": 0.01, |  | ||||||
|             # "off": 0.01, |  | ||||||
|             "policy_type": "motor", |             "policy_type": "motor", | ||||||
|             "weights_scale": 1.0, |             "weights_scale": 0.2, | ||||||
|             "zero_start": True, |             "zero_start": True, | ||||||
|             "zero_goal": False, |             "zero_goal": False, | ||||||
|             "policy_kwargs": { |             "policy_kwargs": { | ||||||
| @ -413,4 +439,4 @@ register( | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| ) | ) | ||||||
| ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("TableTennisProMP-v0") | ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("TableTennisProMP-v2") | ||||||
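The loop above registers `TableTennisProMP-v0` (2D context) and `TableTennisProMP-v1` (4D context) against the existing `-v0` step environments, while the explicit block adds `TableTennisProMP-v2` on top of the new fixed-goal `TableTennis2DCtxt-v1`. A sketch of constructing them; the action-space shape is an assumption about `make_promp_env_helper` exposing one flat weight vector (num_dof * num_basis = 14) per episode:

```python
import gym
import alr_envs  # noqa: F401

for env_id in ["TableTennisProMP-v0", "TableTennisProMP-v1", "TableTennisProMP-v2"]:
    env = gym.make(env_id)
    env.reset()
    # assumed: a flat ProMP weight vector of size num_dof * num_basis = 14
    print(env_id, env.action_space.shape)
```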
@@ -14,8 +14,5 @@
 |`ViaPointReacherDMP-v0`| A DMP provides a trajectory for the `ViaPointReacher-v0` task. | 200 | 25
 |`HoleReacherFixedGoalDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task with a fixed goal attractor. | 200 | 25
 |`HoleReacherDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task. The goal attractor needs to be learned. | 200 | 30
-|`ALRBallInACupSimpleDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupSimple-v0` task where only 3 joints are actuated. | 4000 | 15
-|`ALRBallInACupDMP-v0`| A DMP provides a trajectory for the `ALRBallInACup-v0` task. | 4000 | 35
-|`ALRBallInACupGoalDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupGoal-v0` task. | 4000 | 35 | 3
 
 [//]:  |`HoleReacherProMPP-v0`|
@@ -2,5 +2,5 @@ from .reacher.alr_reacher import ALRReacherEnv
 from .reacher.balancing import BalancingEnv
 from .ball_in_a_cup.ball_in_a_cup import ALRBallInACupEnv
 from .ball_in_a_cup.biac_pd import ALRBallInACupPDEnv
-from .table_tennis.tt_gym import TT_Env_Gym
+from .table_tennis.tt_gym import TTEnvGym
 from .beerpong.beerpong import ALRBeerBongEnv
@@ -27,10 +27,10 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
         self.ball_site_id = 0
         self.ball_id = 11
 
-        self._release_step = 100  # time step of ball release
+        self._release_step = 175  # time step of ball release
 
-        self.sim_time = 4  # seconds
-        self.ep_length = 600  # based on 5 seconds with dt = 0.005 int(self.sim_time / self.dt)
+        self.sim_time = 3  # seconds
+        self.ep_length = 600  # based on 3 seconds with dt = 0.005 int(self.sim_time / self.dt)
         self.cup_table_id = 10
 
         if noisy:
@@ -143,7 +143,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
                                       q_vel=self.sim.data.qvel[0:7].ravel().copy(),
                                       ball_pos=ball_pos,
                                       ball_vel=ball_vel,
-                                      is_success=success,
+                                      success=success,
                                       is_collided=is_collided, sim_crash=crash)
 
     def check_traj_in_joint_limits(self):
@@ -171,7 +171,7 @@ class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
 
 
 if __name__ == "__main__":
-    env = ALRBeerBongEnv(reward_type="no_context", difficulty='hardest')
+    env = ALRBeerBongEnv(reward_type="staged", difficulty='hardest')
 
     # env.configure(ctxt)
     env.reset()
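The `ep_length` comment is now consistent with the shortened horizon; checking the arithmetic with the values from the hunk above (dt = 0.005 s per the comment):

```python
sim_time = 3   # seconds (new value above)
dt = 0.005     # seconds per simulation step, taken from the comment
assert int(sim_time / dt) == 600  # matches self.ep_length
```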
@@ -71,6 +71,7 @@ class BeerPongReward:
 
         goal_pos = env.sim.data.site_xpos[self.goal_id]
         ball_pos = env.sim.data.body_xpos[self.ball_id]
+        ball_vel = env.sim.data.body_xvelp[self.ball_id]
         goal_final_pos = env.sim.data.site_xpos[self.goal_final_id]
         self.dists.append(np.linalg.norm(goal_pos - ball_pos))
         self.dists_final.append(np.linalg.norm(goal_final_pos - ball_pos))
@@ -131,6 +132,7 @@ class BeerPongReward:
         infos["success"] = success
         infos["is_collided"] = self._is_collided
         infos["ball_pos"] = ball_pos.copy()
+        infos["ball_vel"] = ball_vel.copy()
         infos["action_cost"] = 5e-4 * action_cost
 
         return reward, infos
@@ -81,32 +81,36 @@ class BeerPongReward:
         action_cost = np.sum(np.square(action))
         self.action_costs.append(action_cost)
 
+        if not self.ball_table_contact:
+            self.ball_table_contact = self._check_collision_single_objects(env.sim, self.ball_collision_id,
+                                                                           self.table_collision_id)
+
         self._is_collided = self._check_collision_with_itself(env.sim, self.robot_collision_ids)
         if env._steps == env.ep_length - 1 or self._is_collided:
 
             min_dist = np.min(self.dists)
-            ball_table_bounce = self._check_collision_single_objects(env.sim, self.ball_collision_id,
-                                                                     self.table_collision_id)
-            ball_cup_table_cont = self._check_collision_with_set_of_objects(env.sim, self.ball_collision_id,
-                                                                            self.cup_collision_ids)
-            ball_wall_cont = self._check_collision_single_objects(env.sim, self.ball_collision_id,
-                                                                  self.wall_collision_id)
+            final_dist = self.dists_final[-1]
+
             ball_in_cup = self._check_collision_single_objects(env.sim, self.ball_collision_id,
                                                                self.cup_table_collision_id)
-            if not ball_in_cup:
-                cost_offset = 2
-                if not ball_cup_table_cont and not ball_table_bounce and not ball_wall_cont:
-                    cost_offset += 2
-                cost = cost_offset + min_dist ** 2 + 0.5 * self.dists_final[-1] ** 2 + 1e-7 * action_cost
-            else:
-                cost = self.dists_final[-1] ** 2 + 1.5 * action_cost * 1e-7
 
-            reward = - 1 * cost - self.collision_penalty * int(self._is_collided)
+            # encourage bounce before falling into cup
+            if not ball_in_cup:
+                if not self.ball_table_contact:
+                    reward = 0.2 * (1 - np.tanh(min_dist ** 2)) + 0.1 * (1 - np.tanh(final_dist ** 2))
+                else:
+                    reward = (1 - np.tanh(min_dist ** 2)) + 0.5 * (1 - np.tanh(final_dist ** 2))
+            else:
+                if not self.ball_table_contact:
+                    reward = 2 * (1 - np.tanh(final_dist ** 2)) + 1 * (1 - np.tanh(min_dist ** 2)) + 1
+                else:
+                    reward = 2 * (1 - np.tanh(final_dist ** 2)) + 1 * (1 - np.tanh(min_dist ** 2)) + 3
+
+            # reward = - 1 * cost - self.collision_penalty * int(self._is_collided)
             success = ball_in_cup
             crash = self._is_collided
         else:
-            reward = - 1e-7 * action_cost
-            cost = 0
+            reward = - 1e-4 * action_cost
             success = False
             crash = False
 
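The new terminal reward reads as a pure function of four quantities; a standalone re-derivation of the branch above (function and argument names are illustrative, not from the source):

```python
import numpy as np

def staged_reward(ball_in_cup, ball_table_contact, min_dist, final_dist):
    """Tanh-shaped closeness terms plus bonuses for landing in the cup,
    larger when the ball bounced off the table first."""
    near = 1 - np.tanh(min_dist ** 2)     # closest approach to the goal site
    final = 1 - np.tanh(final_dist ** 2)  # distance at the end of the episode
    if not ball_in_cup:
        return 0.2 * near + 0.1 * final if not ball_table_contact else near + 0.5 * final
    return 2 * final + near + (1 if not ball_table_contact else 3)
```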
@@ -115,26 +119,11 @@ class BeerPongReward:
         infos["is_collided"] = self._is_collided
         infos["ball_pos"] = ball_pos.copy()
         infos["ball_vel"] = ball_vel.copy()
-        infos["action_cost"] = 5e-4 * action_cost
-        infos["task_cost"] = cost
+        infos["action_cost"] = action_cost
+        infos["task_reward"] = reward
 
         return reward, infos
 
-    def get_cost_offset(self):
-        if self.ball_ground_contact:
-            return 200
-
-        if not self.ball_table_contact:
-            return 100
-
-        if not self.ball_in_cup:
-            return 50
-
-        if self.ball_in_cup and self.ball_cup_contact and not self.noisy_bp:
-            return 10
-
-        return 0
-
     def _check_collision_single_objects(self, sim, id_1, id_2):
         for coni in range(0, sim.data.ncon):
             con = sim.data.contact[coni]
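For reference, every collision query in this reward class scans MuJoCo's active contact list pairwise; a sketch of the pattern that `_check_collision_single_objects` follows (mujoco-py API; the set comparison is an equivalent shorthand for checking both geom orderings):

```python
def geoms_in_contact(sim, id_1, id_2):
    # each active contact stores the ids of the two geoms involved
    for coni in range(sim.data.ncon):
        con = sim.data.contact[coni]
        if {con.geom1, con.geom2} == {id_1, id_2}:
            return True
    return False
```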
@@ -6,8 +6,6 @@ from gym.envs.mujoco import MujocoEnv
 
 class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
     def __init__(self, n_substeps=4, apply_gravity_comp=True, reward_function=None):
-        utils.EzPickle.__init__(**locals())
-
         self._steps = 0
 
         self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
@@ -28,15 +26,13 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
 
         self.context = None
 
-        MujocoEnv.__init__(self, model_path=self.xml_path, frame_skip=n_substeps)
-
         # alr_mujoco_env.AlrMujocoEnv.__init__(self,
         #                                      self.xml_path,
         #                                      apply_gravity_comp=apply_gravity_comp,
         #                                      n_substeps=n_substeps)
 
         self.sim_time = 8  # seconds
-        self.sim_steps = int(self.sim_time / self.dt)
+        # self.sim_steps = int(self.sim_time / self.dt)
         if reward_function is None:
             from alr_envs.alr.mujoco.beerpong.beerpong_reward_simple import BeerpongReward
             reward_function = BeerpongReward
@@ -46,6 +42,9 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
         self.cup_table_id = self.sim.model._body_name2id["cup_table"]
         # self.bounce_table_id = self.sim.model._body_name2id["bounce_table"]
 
+        MujocoEnv.__init__(self, model_path=self.xml_path, frame_skip=n_substeps)
+        utils.EzPickle.__init__(self)
+
     @property
     def current_pos(self):
         return self.sim.data.qpos[0:7].copy()
@@ -90,7 +89,7 @@ class ALRBeerpongEnv(MujocoEnv, utils.EzPickle):
         reward_ctrl = - np.square(a).sum()
         action_cost = np.sum(np.square(a))
 
-        crash = self.do_simulation(a)
+        crash = self.do_simulation(a, self.frame_skip)
         joint_cons_viol = self.check_traj_in_joint_limits()
 
         self._q_pos.append(self.sim.data.qpos[0:7].ravel().copy())
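The added argument matches gym's `MujocoEnv.do_simulation(ctrl, n_frames)` signature, which roughly does the following (a sketch of the base-class behavior; any crash flag returned here would come from an override in this repo):

```python
def do_simulation(self, ctrl, n_frames):
    # write the control signal once, then advance the physics n_frames times
    self.sim.data.ctrl[:] = ctrl
    for _ in range(n_frames):
        self.sim.step()
```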
@@ -10,7 +10,7 @@ from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward
 
 #TODO: Check for simulation stability. Make sure the code runs even for sim crash
 
-MAX_EPISODE_STEPS = 1375
+MAX_EPISODE_STEPS = 2875
 BALL_NAME_CONTACT = "target_ball_contact"
 BALL_NAME = "target_ball"
 TABLE_NAME = 'table_tennis_table'
@@ -22,15 +22,20 @@ RACKET_NAME = 'bat'
 CONTEXT_RANGE_BOUNDS_2DIM = np.array([[-1.2, -0.6], [-0.2, 0.0]])
 CONTEXT_RANGE_BOUNDS_4DIM = np.array([[-1.35, -0.75, -1.25, -0.75], [-0.1, 0.75, -0.1, 0.75]])
 
-class TT_Env_Gym(MujocoEnv, utils.EzPickle):
 
-    def __init__(self, ctxt_dim=2):
+class TTEnvGym(MujocoEnv, utils.EzPickle):
+
+    def __init__(self, ctxt_dim=2, fixed_goal=False):
         model_path = os.path.join(os.path.dirname(__file__), "xml", 'table_tennis_env.xml')
 
         self.ctxt_dim = ctxt_dim
+        self.fixed_goal = fixed_goal
         if ctxt_dim == 2:
             self.context_range_bounds = CONTEXT_RANGE_BOUNDS_2DIM
-            self.goal = np.zeros(3)  # 2 x,y + 1z
+            if self.fixed_goal:
+                self.goal = np.array([-1, -0.1, 0])
+            else:
+                self.goal = np.zeros(3)  # 2 x,y + 1z
         elif ctxt_dim == 4:
             self.context_range_bounds = CONTEXT_RANGE_BOUNDS_4DIM
             self.goal = np.zeros(3)
@@ -47,10 +52,10 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
 
         self.reward_func = TT_Reward(self.ctxt_dim)
         self.ball_landing_pos = None
-        self.hited_ball = False
+        self.hit_ball = False
         self.ball_contact_after_hit = False
         self._ids_set = False
-        super(TT_Env_Gym, self).__init__(model_path=model_path, frame_skip=1)
+        super(TTEnvGym, self).__init__(model_path=model_path, frame_skip=1)
         self.ball_id = self.sim.model._body_name2id[BALL_NAME]  # find the proper -> not protected func.
         self.ball_contact_id = self.sim.model._geom_name2id[BALL_NAME_CONTACT]
         self.table_contact_id = self.sim.model._geom_name2id[TABLE_NAME]
@@ -77,15 +82,18 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
         return obs
 
     def sample_context(self):
-        return np.random.uniform(self.context_range_bounds[0], self.context_range_bounds[1], size=self.ctxt_dim)
+        return self.np_random.uniform(self.context_range_bounds[0], self.context_range_bounds[1], size=self.ctxt_dim)
 
     def reset_model(self):
         self.set_state(self.init_qpos_tt, self.init_qvel_tt)    # reset to initial sim state
         self.time_steps = 0
         self.ball_landing_pos = None
-        self.hited_ball = False
+        self.hit_ball = False
         self.ball_contact_after_hit = False
-        self.goal = self.sample_context()[:2]
+        if self.fixed_goal:
+            self.goal = self.goal[:2]
+        else:
+            self.goal = self.sample_context()[:2]
         if self.ctxt_dim == 2:
             initial_ball_state = ball_init(random=False)  # fixed velocity, fixed position
         elif self.ctxt_dim == 4:
@@ -122,12 +130,12 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
         if not self._ids_set:
             self._set_ids()
         done = False
-        episode_end = False if self.time_steps+1<MAX_EPISODE_STEPS else True
-        if not self.hited_ball:
-            self.hited_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_1) # check for one side
-            if not self.hited_ball:
-                self.hited_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_2) # check for other side
-        if self.hited_ball:
+        episode_end = False if self.time_steps + 1 < MAX_EPISODE_STEPS else True
+        if not self.hit_ball:
+            self.hit_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_1) # check for one side
+            if not self.hit_ball:
+                self.hit_ball = self._contact_checker(self.ball_contact_id, self.paddle_contact_id_2) # check for other side
+        if self.hit_ball:
             if not self.ball_contact_after_hit:
                 if self._contact_checker(self.ball_contact_id, self.floor_contact_id):  # first check contact with floor
                     self.ball_contact_after_hit = True
@@ -140,7 +148,7 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
         if self.ball_landing_pos is not None:
             done = True
             episode_end =True
-        reward = self.reward_func.get_reward(episode_end, c_ball_pos, racket_pos, self.hited_ball, self.ball_landing_pos)
+        reward = self.reward_func.get_reward(episode_end, c_ball_pos, racket_pos, self.hit_ball, self.ball_landing_pos)
         self.time_steps += 1
         # gravity compensation on joints:
         #action += self.sim.data.qfrc_bias[:7].copy()
@@ -151,7 +159,7 @@ class TT_Env_Gym(MujocoEnv, utils.EzPickle):
             done = True
             reward = -25
         ob = self._get_obs()
-        return ob, reward, done, {"hit_ball":self.hited_ball}# might add some information here ....
+        return ob, reward, done, {"hit_ball": self.hit_ball}  # might add some information here ....
 
     def set_context(self, context):
         old_state = self.sim.get_state()
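Since `sample_context` now draws from `self.np_random` rather than the global NumPy RNG, sampled goal contexts become reproducible through gym's seeding; a sketch (assuming the `env.seed()` API of this gym version):

```python
import gym
import alr_envs  # noqa: F401

env = gym.make("TableTennis2DCtxt-v0")
env.seed(42)
env.reset()
goal_a = env.unwrapped.goal.copy()
env.seed(42)
env.reset()
goal_b = env.unwrapped.goal.copy()
assert (goal_a == goal_b).all()  # same seed, same sampled context
```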