Fix: Some envs had wrong obs space shapes and did not follow new gym spec
This commit is contained in:
parent
80de15fd14
commit
4921cc4b0b
@ -77,7 +77,7 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
||||
self.dist_ground_cup = -1 # distance floor to cup if first floor contact
|
||||
|
||||
self.observation_space = Box(
|
||||
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
|
||||
low=-np.inf, high=np.inf, shape=(29,), dtype=np.float64
|
||||
)
|
||||
|
||||
MujocoEnv.__init__(
|
||||
|
@ -1,8 +1,8 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from gym import utils, spaces
|
||||
from gym.envs.mujoco import MujocoEnv
|
||||
from gymnasium import utils, spaces
|
||||
from gymnasium.envs.mujoco import MujocoEnv
|
||||
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance
|
||||
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max
|
||||
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat
|
||||
@ -51,7 +51,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
||||
self._episode_energy = 0.
|
||||
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
|
||||
low=-np.inf, high=np.inf, shape=(28,), dtype=np.float64
|
||||
)
|
||||
|
||||
MujocoEnv.__init__(self,
|
||||
@ -103,7 +103,11 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
||||
'is_success': True if episode_end and box_goal_pos_dist < 0.05 and box_goal_quat_dist < 0.5 else False,
|
||||
'num_steps': self._steps
|
||||
}
|
||||
return obs, reward, episode_end, infos
|
||||
|
||||
terminated = episode_end and infos['is_success']
|
||||
truncated = episode_end and not infos['is_success']
|
||||
|
||||
return obs, reward, terminated, truncated, infos
|
||||
|
||||
def reset_model(self):
|
||||
# rest box to initial position
|
||||
|
@ -40,11 +40,11 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
|
||||
|
||||
if exclude_current_positions_from_observation:
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
|
||||
low=-np.inf, high=np.inf, shape=(12,), dtype=np.float64
|
||||
)
|
||||
else:
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64
|
||||
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
|
||||
)
|
||||
|
||||
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
||||
@ -136,7 +136,9 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
|
||||
'goal': self.box_x,
|
||||
}
|
||||
|
||||
return observation, reward, terminated, info
|
||||
truncated = self.current_step >= self.max_episode_steps and not terminated
|
||||
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def _get_obs(self):
|
||||
return np.append(super()._get_obs(), self.box_x)
|
||||
|
@ -1,8 +1,8 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from gym import utils, spaces
|
||||
from gym.envs.mujoco import MujocoEnv
|
||||
from gymnasium import utils, spaces
|
||||
from gymnasium.envs.mujoco import MujocoEnv
|
||||
|
||||
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force
|
||||
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
|
||||
@ -60,9 +60,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
|
||||
self._artificial_force = 0.
|
||||
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(9,), dtype=np.float64
|
||||
)
|
||||
if not hasattr(self, 'observation_space'):
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
|
||||
)
|
||||
|
||||
MujocoEnv.__init__(self,
|
||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
||||
@ -146,7 +147,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
|
||||
if self._ball_landing_pos is not None else 10.
|
||||
|
||||
return self._get_obs(), reward, self._terminated, {
|
||||
info = {
|
||||
"hit_ball": self._hit_ball,
|
||||
"ball_returned_success": self._ball_return_success,
|
||||
"land_dist_error": land_dist_err,
|
||||
@ -154,6 +155,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
"num_steps": self._steps,
|
||||
}
|
||||
|
||||
terminated, truncated = self._terminated, False
|
||||
|
||||
return self._get_obs(), reward, terminated, truncated, info
|
||||
|
||||
def _contact_checker(self, id_1, id_2):
|
||||
for coni in range(0, self.data.ncon):
|
||||
con = self.data.contact[coni]
|
||||
@ -251,7 +256,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
|
||||
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
||||
penalty = self._get_traj_invalid_penalty(action, pos_traj)
|
||||
return obs, penalty, True, {
|
||||
return obs, penalty, True, False, {
|
||||
"hit_ball": [False],
|
||||
"ball_returned_success": [False],
|
||||
"land_dist_error": [10.],
|
||||
@ -271,6 +276,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
|
||||
class TableTennisWind(TableTennisEnv):
|
||||
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4):
|
||||
self.observation_space = spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(22,), dtype=np.float64
|
||||
)
|
||||
super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True)
|
||||
|
||||
def _get_obs(self):
|
||||
|
Loading…
Reference in New Issue
Block a user