diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index f8467c4..a774d5d 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -29,6 +29,7 @@ from .mujoco.table_tennis.table_tennis_env import TableTennisEnv, TableTennisWin MAX_EPISODE_STEPS_TABLE_TENNIS, MAX_EPISODE_STEPS_TABLE_TENNIS_MARKOV_VER from .mujoco.table_tennis.mp_wrapper import TT_MPWrapper as MPWrapper_TableTennis from .mujoco.table_tennis.mp_wrapper import TT_MPWrapper_Replan as MPWrapper_TableTennis_Replan +from .mujoco.table_tennis.mp_wrapper import TTRndRobot_MPWrapper as MPWrapper_TableTennis_Rnd from .mujoco.table_tennis.mp_wrapper import TTVelObs_MPWrapper as MPWrapper_TableTennis_VelObs from .mujoco.table_tennis.mp_wrapper import TTVelObs_MPWrapper_Replan as MPWrapper_TableTennis_VelObs_Replan @@ -306,6 +307,7 @@ register( register( id='fancy/TableTennisRndRobot-v0', entry_point='fancy_gym.envs.mujoco:TableTennisRandomInit', + mp_wrapper=MPWrapper_TableTennis_Rnd, max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS, kwargs={ 'random_pos_scale': 0.1, @@ -315,6 +317,7 @@ register( register( id='fancy/TableTennisMarkovian-v0', + mp_wrapper=MPWrapper_TableTennis, entry_point='fancy_gym.envs.mujoco:TableTennisMarkovian', max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS_MARKOV_VER, kwargs={ @@ -323,6 +326,7 @@ register( register( id='fancy/TableTennisRndRobotMarkovian-v0', + mp_wrapper=MPWrapper_TableTennis_Rnd, entry_point='fancy_gym.envs.mujoco:TableTennisMarkovian', max_episode_steps=MAX_EPISODE_STEPS_TABLE_TENNIS_MARKOV_VER, kwargs={ diff --git a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py index fcc31a8..eb27c3e 100644 --- a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py +++ b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py @@ -151,3 +151,15 @@ class TTVelObs_MPWrapper_Replan(TT_MPWrapper_Replan): [True] * 2, # target landing position # [True] * 1, # time ]) + +class TTRndRobot_MPWrapper(TT_MPWrapper): + @property + def context_mask(self): + return np.hstack([ + [True] * 7, # joints position + [False] * 7, # joints velocity + [True] * 2, # position ball x, y + [False] * 1, # position ball z + [True] * 2, # target landing position + # [True] * 1, # time + ]) \ No newline at end of file