diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index 3d6f042..438f322 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -31,7 +31,6 @@ from .mujoco.table_tennis.mp_wrapper import TT_MPWrapper as MPWrapper_TableTenni from .mujoco.table_tennis.mp_wrapper import TT_MPWrapper_Replan as MPWrapper_TableTennis_Replan from .mujoco.table_tennis.mp_wrapper import TTVelObs_MPWrapper as MPWrapper_TableTennis_VelObs from .mujoco.table_tennis.mp_wrapper import TTVelObs_MPWrapper_Replan as MPWrapper_TableTennis_VelObs_Replan -from .mujoco.air_hockey.air_hockey_env_wrapper import MAX_EPISODE_STEPS_AIRHOCKEY # Classic Control # Simple Reacher @@ -294,5 +293,5 @@ register( register( id='fancy/AirHockey-v0', entry_point='fancy_gym.envs.mujoco:AirHockeyEnv', - max_episode_steps=MAX_EPISODE_STEPS_AIRHOCKEY -) + max_episode_steps=45000 +) \ No newline at end of file diff --git a/fancy_gym/envs/mujoco/__init__.py b/fancy_gym/envs/mujoco/__init__.py index 9be0583..7e4be42 100644 --- a/fancy_gym/envs/mujoco/__init__.py +++ b/fancy_gym/envs/mujoco/__init__.py @@ -9,4 +9,9 @@ from .reacher.reacher import ReacherEnv from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse from .table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching -from .air_hockey.air_hockey_env_wrapper import AirHockeyEnv \ No newline at end of file + +from .air_hockey.air_hockey_env_wrapper import AirHockeyEnv +try: + from .air_hockey.air_hockey_env_wrapper import AirHockeyEnv +except: + print("[FANCY GYM] Air Hockey not available (depends on mushroom-rl, dmc, mujoco)") \ No newline at end of file diff --git a/fancy_gym/envs/mujoco/air_hockey/seven_dof/env_double.py b/fancy_gym/envs/mujoco/air_hockey/seven_dof/env_double.py index 086812c..54a5252 100644 --- a/fancy_gym/envs/mujoco/air_hockey/seven_dof/env_double.py +++ b/fancy_gym/envs/mujoco/air_hockey/seven_dof/env_double.py @@ -129,8 +129,8 @@ class AirHockeyDouble(AirHockeyBase): self.q_pos_prev[i] = self.init_state[i] self.q_pos_prev[i + 7] = self.init_state[i] - self.q_vel_prev[i] = self._data.joint("iiwa_1/joint_" + str(i + 1)).qvel - self.q_vel_prev[i + 7] = self._data.joint("iiwa_2/joint_" + str(i + 1)).qvel + self.q_vel_prev[i] = self._data.joint("iiwa_1/joint_" + str(i + 1)).qvel[0] + self.q_vel_prev[i + 7] = self._data.joint("iiwa_2/joint_" + str(i + 1)).qvel[0] self.universal_joint_plugin.reset() diff --git a/setup.py b/setup.py index 1daa568..8926b65 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ extras = { 'mujoco': ['mujoco==2.3.3', 'gymnasium[mujoco]>0.26.0'], 'mujoco-legacy': ['mujoco-py >=2.1,<2.2', 'cython<3'], 'jax': ["jax >=0.4.0", "jaxlib >=0.4.0"], + 'mushroom-rl': ['mushroom-rl'], } # All dependencies