Fix: Some envs had wrong obs space shapes and did not follow new gym spec
This commit is contained in:
parent
80de15fd14
commit
4921cc4b0b
@ -77,7 +77,7 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
self.dist_ground_cup = -1 # distance floor to cup if first floor contact
|
self.dist_ground_cup = -1 # distance floor to cup if first floor contact
|
||||||
|
|
||||||
self.observation_space = Box(
|
self.observation_space = Box(
|
||||||
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
|
low=-np.inf, high=np.inf, shape=(29,), dtype=np.float64
|
||||||
)
|
)
|
||||||
|
|
||||||
MujocoEnv.__init__(
|
MujocoEnv.__init__(
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import utils, spaces
|
from gymnasium import utils, spaces
|
||||||
from gym.envs.mujoco import MujocoEnv
|
from gymnasium.envs.mujoco import MujocoEnv
|
||||||
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance
|
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance
|
||||||
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max
|
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max
|
||||||
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat
|
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat
|
||||||
@ -51,7 +51,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
self._episode_energy = 0.
|
self._episode_energy = 0.
|
||||||
|
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
|
low=-np.inf, high=np.inf, shape=(28,), dtype=np.float64
|
||||||
)
|
)
|
||||||
|
|
||||||
MujocoEnv.__init__(self,
|
MujocoEnv.__init__(self,
|
||||||
@ -103,7 +103,11 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
'is_success': True if episode_end and box_goal_pos_dist < 0.05 and box_goal_quat_dist < 0.5 else False,
|
'is_success': True if episode_end and box_goal_pos_dist < 0.05 and box_goal_quat_dist < 0.5 else False,
|
||||||
'num_steps': self._steps
|
'num_steps': self._steps
|
||||||
}
|
}
|
||||||
return obs, reward, episode_end, infos
|
|
||||||
|
terminated = episode_end and infos['is_success']
|
||||||
|
truncated = episode_end and not infos['is_success']
|
||||||
|
|
||||||
|
return obs, reward, terminated, truncated, infos
|
||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
# rest box to initial position
|
# rest box to initial position
|
||||||
|
@ -40,11 +40,11 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
|
|||||||
|
|
||||||
if exclude_current_positions_from_observation:
|
if exclude_current_positions_from_observation:
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
|
low=-np.inf, high=np.inf, shape=(12,), dtype=np.float64
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64
|
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
|
||||||
)
|
)
|
||||||
|
|
||||||
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
||||||
@ -136,7 +136,9 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
|
|||||||
'goal': self.box_x,
|
'goal': self.box_x,
|
||||||
}
|
}
|
||||||
|
|
||||||
return observation, reward, terminated, info
|
truncated = self.current_step >= self.max_episode_steps and not terminated
|
||||||
|
|
||||||
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
return np.append(super()._get_obs(), self.box_x)
|
return np.append(super()._get_obs(), self.box_x)
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import utils, spaces
|
from gymnasium import utils, spaces
|
||||||
from gym.envs.mujoco import MujocoEnv
|
from gymnasium.envs.mujoco import MujocoEnv
|
||||||
|
|
||||||
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force
|
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force
|
||||||
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
|
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
|
||||||
@ -60,9 +60,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
self._artificial_force = 0.
|
self._artificial_force = 0.
|
||||||
|
|
||||||
self.observation_space = spaces.Box(
|
if not hasattr(self, 'observation_space'):
|
||||||
low=-np.inf, high=np.inf, shape=(9,), dtype=np.float64
|
self.observation_space = spaces.Box(
|
||||||
)
|
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
|
||||||
|
)
|
||||||
|
|
||||||
MujocoEnv.__init__(self,
|
MujocoEnv.__init__(self,
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
||||||
@ -146,7 +147,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
|
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
|
||||||
if self._ball_landing_pos is not None else 10.
|
if self._ball_landing_pos is not None else 10.
|
||||||
|
|
||||||
return self._get_obs(), reward, self._terminated, {
|
info = {
|
||||||
"hit_ball": self._hit_ball,
|
"hit_ball": self._hit_ball,
|
||||||
"ball_returned_success": self._ball_return_success,
|
"ball_returned_success": self._ball_return_success,
|
||||||
"land_dist_error": land_dist_err,
|
"land_dist_error": land_dist_err,
|
||||||
@ -154,6 +155,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
"num_steps": self._steps,
|
"num_steps": self._steps,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
terminated, truncated = self._terminated, False
|
||||||
|
|
||||||
|
return self._get_obs(), reward, terminated, truncated, info
|
||||||
|
|
||||||
def _contact_checker(self, id_1, id_2):
|
def _contact_checker(self, id_1, id_2):
|
||||||
for coni in range(0, self.data.ncon):
|
for coni in range(0, self.data.ncon):
|
||||||
con = self.data.contact[coni]
|
con = self.data.contact[coni]
|
||||||
@ -251,7 +256,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
|
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
|
||||||
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
||||||
penalty = self._get_traj_invalid_penalty(action, pos_traj)
|
penalty = self._get_traj_invalid_penalty(action, pos_traj)
|
||||||
return obs, penalty, True, {
|
return obs, penalty, True, False, {
|
||||||
"hit_ball": [False],
|
"hit_ball": [False],
|
||||||
"ball_returned_success": [False],
|
"ball_returned_success": [False],
|
||||||
"land_dist_error": [10.],
|
"land_dist_error": [10.],
|
||||||
@ -271,6 +276,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
class TableTennisWind(TableTennisEnv):
|
class TableTennisWind(TableTennisEnv):
|
||||||
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4):
|
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4):
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(22,), dtype=np.float64
|
||||||
|
)
|
||||||
super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True)
|
super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True)
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user