Bugfixes and update for deprecated code

This commit is contained in:
Mustafa Enes Batur 2023-11-19 22:45:32 +01:00
parent d0cb6316a5
commit 9aa572271f
7 changed files with 9 additions and 9 deletions

View File

@ -10,8 +10,7 @@ from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse
from .table_tennis.table_tennis_env import TableTennisEnv, TableTennisWind, TableTennisGoalSwitching
from .air_hockey.air_hockey_env_wrapper import AirHockeyEnv
try:
from .air_hockey.air_hockey_env_wrapper import AirHockeyEnv
except:
except ModuleNotFoundError:
print("[FANCY GYM] Air Hockey not available (depends on mushroom-rl, dmc, mujoco)")

View File

@ -45,10 +45,11 @@ class AirHockeyEnv(Environment):
self.base_env = env_dict[env_mode](interpolation_order=interpolation_order, **kwargs)
self.env_name = env_mode
self.env_info = self.base_env.env_info
single_robot_obs_size = len(self.base_env.info.observation_space.low)
if env_mode == "tournament":
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(2,23), dtype=np.float64)
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(2,single_robot_obs_size), dtype=np.float64)
else:
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(20,), dtype=np.float64)
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(single_robot_obs_size,), dtype=np.float64)
robot_info = self.env_info["robot"]
if env_mode != "tournament":

View File

@ -24,7 +24,7 @@ class AirHockeyDefend(AirHockeySingle):
puck_vel = np.zeros(3)
puck_vel[0] = -np.cos(angle) * lin_vel
puck_vel[1] = np.sin(angle) * lin_vel
puck_vel[2] = np.random.uniform(-10, 10, 1)
puck_vel[2] = np.random.uniform(-10, 10)
self._write_data("puck_x_pos", puck_pos[0])
self._write_data("puck_y_pos", puck_pos[1])

View File

@ -94,7 +94,7 @@ class AirHockeySingle(AirHockeyBase):
for i in range(7):
self._data.joint("iiwa_1/joint_" + str(i + 1)).qpos = self.init_state[i]
self.q_pos_prev[i] = self.init_state[i]
self.q_vel_prev[i] = self._data.joint("iiwa_1/joint_" + str(i + 1)).qvel
self.q_vel_prev[i] = self._data.joint("iiwa_1/joint_" + str(i + 1)).qvel[0]
self.universal_joint_plugin.reset()

View File

@ -40,7 +40,7 @@ class AirHockeyHit(AirHockeySingle):
puck_vel = np.zeros(3)
puck_vel[0] = -np.cos(angle) * lin_vel
puck_vel[1] = np.sin(angle) * lin_vel
puck_vel[2] = np.random.uniform(-2, 2, 1)
puck_vel[2] = np.random.uniform(-2, 2)
self._write_data("puck_x_vel", puck_vel[0])
self._write_data("puck_y_vel", puck_vel[1])

View File

@ -27,7 +27,7 @@ class AirHockeyDefend(AirHockeySingle):
puck_vel = np.zeros(3)
puck_vel[0] = -np.cos(angle) * lin_vel
puck_vel[1] = np.sin(angle) * lin_vel
puck_vel[2] = np.random.uniform(-10, 10, 1)
puck_vel[2] = np.random.uniform(-10, 10)
self._write_data("puck_x_pos", puck_pos[0])
self._write_data("puck_y_pos", puck_pos[1])

View File

@ -71,7 +71,7 @@ class AirHockeySingle(AirHockeyBase):
for i in range(3):
self._data.joint("planar_robot_1/joint_" + str(i + 1)).qpos = self.init_state[i]
self.q_pos_prev[i] = self.init_state[i]
self.q_vel_prev[i] = self._data.joint("planar_robot_1/joint_" + str(i + 1)).qvel
self.q_vel_prev[i] = self._data.joint("planar_robot_1/joint_" + str(i + 1)).qvel[0]
mujoco.mj_fwdPosition(self._model, self._data)
super().setup(state)