fix tt issues -> context + traj.length
This commit is contained in:
parent
66be0b1e02
commit
2a27f59e50
@ -236,6 +236,17 @@ register(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Beerpong devel big table
|
||||||
|
register(
|
||||||
|
id='ALRBeerPong-v3',
|
||||||
|
entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv',
|
||||||
|
max_episode_steps=600,
|
||||||
|
kwargs={
|
||||||
|
"rndm_goal": True,
|
||||||
|
"cup_goal_pos": [-0.3, -1.2]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Motion Primitive Environments
|
# Motion Primitive Environments
|
||||||
|
|
||||||
## Simple Reacher
|
## Simple Reacher
|
||||||
@ -402,6 +413,32 @@ for _v in _versions:
|
|||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
|
## Beerpong- Big table devel
|
||||||
|
|
||||||
|
register(
|
||||||
|
id='BeerpongProMP-v3',
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
|
||||||
|
kwargs={
|
||||||
|
"name": f"alr_envs:ALRBeerPong-v3",
|
||||||
|
"wrappers": [mujoco.beerpong.MPWrapper],
|
||||||
|
"mp_kwargs": {
|
||||||
|
"num_dof": 7,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 1,
|
||||||
|
"post_traj_time": 2,
|
||||||
|
"policy_type": "motor",
|
||||||
|
"weights_scale": 1,
|
||||||
|
"zero_start": True,
|
||||||
|
"zero_goal": False,
|
||||||
|
"policy_kwargs": {
|
||||||
|
"p_gains": np.array([ 1.5, 5, 2.55, 3, 2., 2, 1.25]),
|
||||||
|
"d_gains": np.array([0.02333333, 0.1, 0.0625, 0.08, 0.03, 0.03, 0.0125])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append('BeerpongProMP-v3')
|
||||||
|
|
||||||
## Table Tennis
|
## Table Tennis
|
||||||
ctxt_dim = [2, 4]
|
ctxt_dim = [2, 4]
|
||||||
for _v, cd in enumerate(ctxt_dim):
|
for _v, cd in enumerate(ctxt_dim):
|
||||||
@ -416,7 +453,7 @@ for _v, cd in enumerate(ctxt_dim):
|
|||||||
"num_dof": 7,
|
"num_dof": 7,
|
||||||
"num_basis": 2,
|
"num_basis": 2,
|
||||||
"duration": 1.25,
|
"duration": 1.25,
|
||||||
"post_traj_time": 4.5,
|
"post_traj_time": 1.5,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 1.0,
|
"weights_scale": 1.0,
|
||||||
"zero_start": True,
|
"zero_start": True,
|
||||||
|
@ -132,18 +132,19 @@
|
|||||||
</body>
|
</body>
|
||||||
</body>
|
</body>
|
||||||
</body>
|
</body>
|
||||||
<body name="table_body" pos="0 -1.85 0.4025">
|
<body name="table_body" pos="0 -2.8 0.4025">
|
||||||
<geom name="table" type="box" size="0.4 0.6 0.4" rgba="0.8 0.655 0.45 1" solimp="0.999 0.999 0.001"
|
<geom name="table" type="box" size="1.5 1.5 0.4" rgba="0.8 0.655 0.45 1" solimp="0.999 0.999 0.001"
|
||||||
solref="-10000 -100"/>
|
solref="-10000 -100"/>
|
||||||
<geom name="table_contact_geom" type="box" size="0.4 0.6 0.1" pos="0 0 0.31" rgba="1.4 0.8 0.45 1" solimp="0.999 0.999 0.001"
|
<geom name="table_contact_geom" type="box" size="1.5 1.5 0.1" pos="0 0 0.31" rgba="1.4 0.8 0.45 1" solimp="0.999 0.999 0.001"
|
||||||
solref="-10000 -100"/>
|
solref="-10000 -100"/>
|
||||||
</body>
|
</body>
|
||||||
<geom name="table_robot" type="box" size="0.1 0.1 0.3" pos="0 0.00 0.3025" rgba="0.8 0.655 0.45 1" solimp="0.999 0.999 0.001"
|
<geom name="table_robot" type="box" size="0.1 0.1 0.3" pos="0 0.00 0.3025" rgba="0.8 0.655 0.45 1" solimp="0.999 0.999 0.001"
|
||||||
solref="-10000 -100"/>
|
solref="-10000 -100"/>
|
||||||
<geom name="wall" type="box" quat="1 0 0 0" size="0.4 0.04 1.1" pos="0. -2.45 1.1" rgba="0.8 0.655 0.45 1" solimp="0.999 0.999 0.001"
|
<geom name="wall" type="box" quat="1 0 0 0" size="1.5 0.04 1.4" pos="0. -4.3 1.4" rgba="0.8 0.655 0.45 1" solimp="0.999 0.999 0.001"
|
||||||
solref="-10000 -100"/>
|
solref="-10000 -100"/>
|
||||||
|
|
||||||
<body name="cup_table" pos="0.32 -1.55 0.84" quat="0.7071068 0.7071068 0 0">
|
<!-- <body name="cup_table" pos="0.32 -1.55 0.84" quat="0.7071068 0.7071068 0 0">-->
|
||||||
|
<body name="cup_table" pos="1.42 -1.25 0.84" quat="0.7071068 0.7071068 0 0">
|
||||||
<inertial pos="-3.75236e-10 8.27811e-05 0.0947015" quat="0.999945 -0.0104888 0 0" mass="10.132" diaginertia="0.000285643 0.000270485 9.65696e-05" />
|
<inertial pos="-3.75236e-10 8.27811e-05 0.0947015" quat="0.999945 -0.0104888 0 0" mass="10.132" diaginertia="0.000285643 0.000270485 9.65696e-05" />
|
||||||
<geom priority= "1" name="cup_geom_table3" pos="0 0.1 0.001" euler="-1.57 0 0" solref="-10000 -100" type="mesh" mesh="cup3_table" mass="10"/>
|
<geom priority= "1" name="cup_geom_table3" pos="0 0.1 0.001" euler="-1.57 0 0" solref="-10000 -100" type="mesh" mesh="cup3_table" mass="10"/>
|
||||||
<geom priority= "1" name="cup_geom_table4" pos="0 0.1 0.001" euler="-1.57 0 0" solref="-10000 -100" type="mesh" mesh="cup4_table" mass="10"/>
|
<geom priority= "1" name="cup_geom_table4" pos="0 0.1 0.001" euler="-1.57 0 0" solref="-10000 -100" type="mesh" mesh="cup4_table" mass="10"/>
|
||||||
|
@ -7,8 +7,11 @@ from gym.envs.mujoco import MujocoEnv
|
|||||||
from alr_envs.alr.mujoco.beerpong.beerpong_reward_staged import BeerPongReward
|
from alr_envs.alr.mujoco.beerpong.beerpong_reward_staged import BeerPongReward
|
||||||
|
|
||||||
|
|
||||||
CUP_POS_MIN = np.array([-0.32, -2.2])
|
# CUP_POS_MIN = np.array([-0.32, -2.2])
|
||||||
CUP_POS_MAX = np.array([0.32, -1.2])
|
# CUP_POS_MAX = np.array([0.32, -1.2])
|
||||||
|
|
||||||
|
CUP_POS_MIN = np.array([-1.42, -4.05])
|
||||||
|
CUP_POS_MAX = np.array([1.42, -1.25])
|
||||||
|
|
||||||
|
|
||||||
class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
class ALRBeerBongEnv(MujocoEnv, utils.EzPickle):
|
||||||
|
@ -11,9 +11,9 @@ class MPWrapper(MPEnvWrapper):
|
|||||||
def active_obs(self):
|
def active_obs(self):
|
||||||
# TODO: @Max Filter observations correctly
|
# TODO: @Max Filter observations correctly
|
||||||
return np.hstack([
|
return np.hstack([
|
||||||
[True] * 7, # Joint Pos
|
[False] * 7, # Joint Pos
|
||||||
[True] * 3, # Ball pos
|
[True] * 2, # Ball pos
|
||||||
[True] * 3 # goal pos
|
[True] * 2 # goal pos
|
||||||
])
|
])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -10,7 +10,8 @@ from alr_envs.alr.mujoco.table_tennis.tt_reward import TT_Reward
|
|||||||
|
|
||||||
#TODO: Check for simulation stability. Make sure the code runs even for sim crash
|
#TODO: Check for simulation stability. Make sure the code runs even for sim crash
|
||||||
|
|
||||||
MAX_EPISODE_STEPS = 1750
|
# MAX_EPISODE_STEPS = 1750
|
||||||
|
MAX_EPISODE_STEPS = 1375
|
||||||
BALL_NAME_CONTACT = "target_ball_contact"
|
BALL_NAME_CONTACT = "target_ball_contact"
|
||||||
BALL_NAME = "target_ball"
|
BALL_NAME = "target_ball"
|
||||||
TABLE_NAME = 'table_tennis_table'
|
TABLE_NAME = 'table_tennis_table'
|
||||||
@ -76,10 +77,11 @@ class TTEnvGym(MujocoEnv, utils.EzPickle):
|
|||||||
self._ids_set = True
|
self._ids_set = True
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
ball_pos = self.sim.data.body_xpos[self.ball_id]
|
ball_pos = self.sim.data.body_xpos[self.ball_id][:2].copy()
|
||||||
|
goal_pos = self.goal[:2].copy()
|
||||||
obs = np.concatenate([self.sim.data.qpos[:7].copy(), # 7 joint positions
|
obs = np.concatenate([self.sim.data.qpos[:7].copy(), # 7 joint positions
|
||||||
ball_pos,
|
ball_pos,
|
||||||
self.goal.copy()])
|
goal_pos])
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def sample_context(self):
|
def sample_context(self):
|
||||||
|
2
setup.py
2
setup.py
@ -7,7 +7,7 @@ setup(
|
|||||||
install_requires=[
|
install_requires=[
|
||||||
'gym',
|
'gym',
|
||||||
'PyQt5',
|
'PyQt5',
|
||||||
'matplotlib',
|
#'matplotlib',
|
||||||
#'mp_env_api @ git+https://github.com/ALRhub/motion_primitive_env_api.git',
|
#'mp_env_api @ git+https://github.com/ALRhub/motion_primitive_env_api.git',
|
||||||
# 'mp_env_api @ git+ssh://git@github.com/ALRhub/motion_primitive_env_api.git',
|
# 'mp_env_api @ git+ssh://git@github.com/ALRhub/motion_primitive_env_api.git',
|
||||||
'mujoco-py<2.1,>=2.0',
|
'mujoco-py<2.1,>=2.0',
|
||||||
|
Loading…
Reference in New Issue
Block a user