Fix table tennis (tt) issues: context and trajectory length
This commit is contained in:
		
							parent
							
								
									66be0b1e02
								
							
						
					
					
						commit
						2a27f59e50
					
				| @ -236,6 +236,17 @@ register( | ||||
|         } | ||||
|     ) | ||||
| 
 | ||||
| # Beerpong development variant with a big table | ||||
| register( | ||||
|         id='ALRBeerPong-v3', | ||||
|         entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv', | ||||
|         max_episode_steps=600, | ||||
|         kwargs={ | ||||
|             "rndm_goal": True, | ||||
|             "cup_goal_pos": [-0.3, -1.2] | ||||
|         } | ||||
|     ) | ||||
| 
 | ||||
| # Motion Primitive Environments | ||||
| 
 | ||||
| ## Simple Reacher | ||||
| @ -402,6 +413,32 @@ for _v in _versions: | ||||
|     ) | ||||
|     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) | ||||
| 
 | ||||
| ## Beerpong - big table (development) | ||||
| 
 | ||||
| register( | ||||
|         id='BeerpongProMP-v3', | ||||
|         entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper', | ||||
|         kwargs={ | ||||
|             "name": f"alr_envs:ALRBeerPong-v3", | ||||
|             "wrappers": [mujoco.beerpong.MPWrapper], | ||||
|             "mp_kwargs": { | ||||
|                 "num_dof": 7, | ||||
|                 "num_basis": 5, | ||||
|                 "duration": 1, | ||||
|                 "post_traj_time": 2, | ||||
|                 "policy_type": "motor", | ||||
|                 "weights_scale": 1, | ||||
|                 "zero_start": True, | ||||
|                 "zero_goal": False, | ||||
|                 "policy_kwargs": { | ||||
|                     "p_gains": np.array([       1.5,   5,   2.55,    3,   2.,    2,   1.25]), | ||||
|                     "d_gains": np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125]) | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     ) | ||||
| ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append('BeerpongProMP-v3') | ||||
| 
 | ||||
| ## Table Tennis | ||||
| ctxt_dim = [2, 4] | ||||
| for _v, cd in enumerate(ctxt_dim): | ||||
| @ -416,7 +453,7 @@ for _v, cd in enumerate(ctxt_dim): | ||||
|                 "num_dof": 7, | ||||
|                 "num_basis": 2, | ||||
|                 "duration": 1.25, | ||||
|                 "post_traj_time": 4.5, | ||||
|                 "post_traj_time": 1.5, | ||||
|                 "policy_type": "motor", | ||||
|                 "weights_scale": 1.0, | ||||
|                 "zero_start": True, | ||||
|  | ||||
| @ -132,18 +132,19 @@ | ||||
|                 </body> | ||||
|             </body> | ||||
|         </body> | ||||
|         <body name="table_body" pos="0 -1.85 0.4025"> | ||||
|             <geom name="table" type="box" size="0.4 0.6 0.4" rgba="0.8 0.655 0.45 1" solimp="0.999 0.999 0.001" | ||||
|          <body name="table_body" pos="0 -2.8 0.4025"> | ||||
|             <geom name="table" type="box" size="1.5 1.5 0.4" rgba="0.8 0.655 0.45 1" solimp="0.999 0.999 0.001" | ||||
|                   solref="-10000 -100"/> | ||||
|             <geom name="table_contact_geom" type="box" size="0.4 0.6 0.1" pos="0 0 0.31" rgba="1.4 0.8 0.45 1" solimp="0.999 0.999 0.001" | ||||
|             <geom name="table_contact_geom" type="box" size="1.5 1.5 0.1" pos="0 0 0.31" rgba="1.4 0.8 0.45 1" solimp="0.999 0.999 0.001" | ||||
|                   solref="-10000 -100"/> | ||||
|         </body> | ||||
|         <geom name="table_robot" type="box" size="0.1 0.1 0.3" pos="0 0.00 0.3025" rgba="0.8 0.655 0.45 1" solimp="0.999 0.999 0.001" | ||||
|                   solref="-10000 -100"/> | ||||
|         <geom name="wall" type="box" quat="1 0 0 0" size="0.4 0.04 1.1" pos="0. -2.45 1.1" rgba="0.8 0.655 0.45 1" solimp="0.999 0.999 0.001" | ||||
|         <geom name="wall" type="box" quat="1 0 0 0" size="1.5 0.04 1.4" pos="0. -4.3 1.4" rgba="0.8 0.655 0.45 1" solimp="0.999 0.999 0.001" | ||||
|                   solref="-10000 -100"/> | ||||
| 
 | ||||
|         <body name="cup_table" pos="0.32 -1.55 0.84" quat="0.7071068 0.7071068 0 0"> | ||||
| <!--        <body name="cup_table" pos="0.32 -1.55 0.84" quat="0.7071068 0.7071068 0 0">--> | ||||
|         <body name="cup_table" pos="1.42 -1.25 0.84" quat="0.7071068 0.7071068 0 0"> | ||||
|             <inertial pos="-3.75236e-10 8.27811e-05 0.0947015" quat="0.999945 -0.0104888 0 0" mass="10.132" diaginertia="0.000285643 0.000270485 9.65696e-05" /> | ||||
|             <geom priority= "1" name="cup_geom_table3" pos="0 0.1 0.001" euler="-1.57 0 0" solref="-10000 -100" type="mesh" mesh="cup3_table" mass="10"/> | ||||
|             <geom priority= "1" name="cup_geom_table4" pos="0 0.1 0.001" euler="-1.57 0 0" solref="-10000 -100" type="mesh" mesh="cup4_table" mass="10"/> | ||||
|  | ||||
| @ -7,8 +7,11 @@ from gym.envs.mujoco import MujocoEnv | ||||
| from alr_envs.alr.mujoco.beerpong.beerpong_reward_staged import BeerPongReward | ||||
| 
 | ||||
| 
 | ||||
| CUP_POS_MIN = np.array([-0.32, -2.2]) | ||||
| CUP_POS_MAX = np.array([0.32, -1.2]) | ||||
| # CUP_POS_MIN = np.array([-0.32, -2.2]) | ||||
| # CUP_POS_MAX = np.array([0.32, -1.2]) | ||||
| 
 | ||||
| CUP_POS_MIN = np.array([-1.42, -4.05]) | ||||
| CUP_POS_MAX = np.array([1.42, -1.25]) | ||||
| 
 | ||||
| 
 | ||||
| class ALRBeerBongEnv(MujocoEnv, utils.EzPickle): | ||||
|  | ||||
| @ -11,9 +11,9 @@ class MPWrapper(MPEnvWrapper): | ||||
|     def active_obs(self): | ||||
|         # TODO: @Max Filter observations correctly | ||||
|         return np.hstack([ | ||||
|             [True] * 7,  # Joint Pos | ||||
|             [True] * 3, # Ball pos | ||||
|             [True] * 3  # goal pos | ||||
|             [False] * 7,  # Joint Pos | ||||
|             [True] * 2, # Ball pos | ||||
|             [True] * 2  # goal pos | ||||
|         ]) | ||||
| 
 | ||||
|     @property | ||||
|  | ||||
| @ -10,7 +10,8 @@ from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward | ||||
| 
 | ||||
| # TODO: Check for simulation stability. Make sure the code keeps running even if the simulation crashes. | ||||
| 
 | ||||
| MAX_EPISODE_STEPS = 1750 | ||||
| # MAX_EPISODE_STEPS = 1750 | ||||
| MAX_EPISODE_STEPS = 1375 | ||||
| BALL_NAME_CONTACT = "target_ball_contact" | ||||
| BALL_NAME = "target_ball" | ||||
| TABLE_NAME = 'table_tennis_table' | ||||
| @ -76,10 +77,11 @@ class TTEnvGym(MujocoEnv, utils.EzPickle): | ||||
|         self._ids_set = True | ||||
| 
 | ||||
|     def _get_obs(self): | ||||
|         ball_pos = self.sim.data.body_xpos[self.ball_id] | ||||
|         ball_pos = self.sim.data.body_xpos[self.ball_id][:2].copy() | ||||
|         goal_pos = self.goal[:2].copy() | ||||
|         obs = np.concatenate([self.sim.data.qpos[:7].copy(),  # 7 joint positions | ||||
|                               ball_pos, | ||||
|                               self.goal.copy()]) | ||||
|                               goal_pos]) | ||||
|         return obs | ||||
| 
 | ||||
|     def sample_context(self): | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user