Fix: Some envs had wrong obs space shapes and did not follow new gym spec

This commit is contained in:
Dominik Moritz Roth 2023-06-11 13:45:44 +02:00
parent 80de15fd14
commit 4921cc4b0b
4 changed files with 29 additions and 15 deletions

View File

@ -77,7 +77,7 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
self.dist_ground_cup = -1 # distance floor to cup if first floor contact self.dist_ground_cup = -1 # distance floor to cup if first floor contact
self.observation_space = Box( self.observation_space = Box(
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64 low=-np.inf, high=np.inf, shape=(29,), dtype=np.float64
) )
MujocoEnv.__init__( MujocoEnv.__init__(

View File

@ -1,8 +1,8 @@
import os import os
import numpy as np import numpy as np
from gym import utils, spaces from gymnasium import utils, spaces
from gym.envs.mujoco import MujocoEnv from gymnasium.envs.mujoco import MujocoEnv
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat
@ -51,7 +51,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
self._episode_energy = 0. self._episode_energy = 0.
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64 low=-np.inf, high=np.inf, shape=(28,), dtype=np.float64
) )
MujocoEnv.__init__(self, MujocoEnv.__init__(self,
@ -103,7 +103,11 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
'is_success': True if episode_end and box_goal_pos_dist < 0.05 and box_goal_quat_dist < 0.5 else False, 'is_success': True if episode_end and box_goal_pos_dist < 0.05 and box_goal_quat_dist < 0.5 else False,
'num_steps': self._steps 'num_steps': self._steps
} }
return obs, reward, episode_end, infos
terminated = episode_end and infos['is_success']
truncated = episode_end and not infos['is_success']
return obs, reward, terminated, truncated, infos
def reset_model(self): def reset_model(self):
# rest box to initial position # rest box to initial position

View File

@ -40,11 +40,11 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
if exclude_current_positions_from_observation: if exclude_current_positions_from_observation:
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64 low=-np.inf, high=np.inf, shape=(12,), dtype=np.float64
) )
else: else:
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64 low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
) )
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
@ -136,7 +136,9 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
'goal': self.box_x, 'goal': self.box_x,
} }
return observation, reward, terminated, info truncated = self.current_step >= self.max_episode_steps and not terminated
return observation, reward, terminated, truncated, info
def _get_obs(self): def _get_obs(self):
return np.append(super()._get_obs(), self.box_x) return np.append(super()._get_obs(), self.box_x)

View File

@ -1,8 +1,8 @@
import os import os
import numpy as np import numpy as np
from gym import utils, spaces from gymnasium import utils, spaces
from gym.envs.mujoco import MujocoEnv from gymnasium.envs.mujoco import MujocoEnv
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
@ -60,9 +60,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self._artificial_force = 0. self._artificial_force = 0.
self.observation_space = spaces.Box( if not hasattr(self, 'observation_space'):
low=-np.inf, high=np.inf, shape=(9,), dtype=np.float64 self.observation_space = spaces.Box(
) low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
)
MujocoEnv.__init__(self, MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
@ -146,7 +147,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \ land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \
if self._ball_landing_pos is not None else 10. if self._ball_landing_pos is not None else 10.
return self._get_obs(), reward, self._terminated, { info = {
"hit_ball": self._hit_ball, "hit_ball": self._hit_ball,
"ball_returned_success": self._ball_return_success, "ball_returned_success": self._ball_return_success,
"land_dist_error": land_dist_err, "land_dist_error": land_dist_err,
@ -154,6 +155,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
"num_steps": self._steps, "num_steps": self._steps,
} }
terminated, truncated = self._terminated, False
return self._get_obs(), reward, terminated, truncated, info
def _contact_checker(self, id_1, id_2): def _contact_checker(self, id_1, id_2):
for coni in range(0, self.data.ncon): for coni in range(0, self.data.ncon):
con = self.data.contact[coni] con = self.data.contact[coni]
@ -251,7 +256,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs): def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
penalty = self._get_traj_invalid_penalty(action, pos_traj) penalty = self._get_traj_invalid_penalty(action, pos_traj)
return obs, penalty, True, { return obs, penalty, True, False, {
"hit_ball": [False], "hit_ball": [False],
"ball_returned_success": [False], "ball_returned_success": [False],
"land_dist_error": [10.], "land_dist_error": [10.],
@ -271,6 +276,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
class TableTennisWind(TableTennisEnv): class TableTennisWind(TableTennisEnv):
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4): def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4):
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(22,), dtype=np.float64
)
super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True) super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True)
def _get_obs(self): def _get_obs(self):