Fix: Some envs had wrong obs space shapes and did not follow new gym spec

This commit is contained in:
Dominik Moritz Roth 2023-06-11 13:45:44 +02:00
parent 80de15fd14
commit 4921cc4b0b
4 changed files with 29 additions and 15 deletions

View File

@ -77,7 +77,7 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
self.dist_ground_cup = -1 # distance floor to cup if first floor contact
self.observation_space = Box(
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
low=-np.inf, high=np.inf, shape=(29,), dtype=np.float64
)
MujocoEnv.__init__(

View File

@ -1,8 +1,8 @@
import os
import numpy as np
from gym import utils, spaces
from gym.envs.mujoco import MujocoEnv
from gymnasium import utils, spaces
from gymnasium.envs.mujoco import MujocoEnv
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat
@ -51,7 +51,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
self._episode_energy = 0.
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
low=-np.inf, high=np.inf, shape=(28,), dtype=np.float64
)
MujocoEnv.__init__(self,
@ -103,7 +103,11 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
'is_success': True if episode_end and box_goal_pos_dist < 0.05 and box_goal_quat_dist < 0.5 else False,
'num_steps': self._steps
}
return obs, reward, episode_end, infos
terminated = episode_end and infos['is_success']
truncated = episode_end and not infos['is_success']
return obs, reward, terminated, truncated, infos
def reset_model(self):
# rest box to initial position

View File

@ -40,11 +40,11 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
if exclude_current_positions_from_observation:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
low=-np.inf, high=np.inf, shape=(12,), dtype=np.float64
)
else:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
)
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
@ -136,7 +136,9 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
'goal': self.box_x,
}
return observation, reward, terminated, info
truncated = self.current_step >= self.max_episode_steps and not terminated
return observation, reward, terminated, truncated, info
def _get_obs(self):
return np.append(super()._get_obs(), self.box_x)

View File

@ -1,8 +1,8 @@
import os
import numpy as np
from gym import utils, spaces
from gym.envs.mujoco import MujocoEnv
from gymnasium import utils, spaces
from gymnasium.envs.mujoco import MujocoEnv
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
@ -60,9 +60,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self._artificial_force = 0.
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(9,), dtype=np.float64
)
if not hasattr(self, 'observation_space'):
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
)
MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
@ -146,7 +147,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
if self._ball_landing_pos is not None else 10.
return self._get_obs(), reward, self._terminated, {
info = {
"hit_ball": self._hit_ball,
"ball_returned_success": self._ball_return_success,
"land_dist_error": land_dist_err,
@ -154,6 +155,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
"num_steps": self._steps,
}
terminated, truncated = self._terminated, False
return self._get_obs(), reward, terminated, truncated, info
def _contact_checker(self, id_1, id_2):
for coni in range(0, self.data.ncon):
con = self.data.contact[coni]
@ -251,7 +256,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
penalty = self._get_traj_invalid_penalty(action, pos_traj)
return obs, penalty, True, {
return obs, penalty, True, False, {
"hit_ball": [False],
"ball_returned_success": [False],
"land_dist_error": [10.],
@ -271,6 +276,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
class TableTennisWind(TableTennisEnv):
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4):
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(22,), dtype=np.float64
)
super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True)
def _get_obs(self):