From 4921cc4b0b3df0fd4c1fe3ed1e678e5346507e45 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 11 Jun 2023 13:45:44 +0200 Subject: [PATCH] Fix: Some envs had wrong obs space shapes and did not follow new gym spec --- fancy_gym/envs/mujoco/beerpong/beerpong.py | 2 +- .../mujoco/box_pushing/box_pushing_env.py | 12 ++++++---- .../mujoco/hopper_jump/hopper_jump_on_box.py | 8 ++++--- .../mujoco/table_tennis/table_tennis_env.py | 22 +++++++++++++------ 4 files changed, 29 insertions(+), 15 deletions(-) diff --git a/fancy_gym/envs/mujoco/beerpong/beerpong.py b/fancy_gym/envs/mujoco/beerpong/beerpong.py index 8e2f9fc..fd1a5dc 100644 --- a/fancy_gym/envs/mujoco/beerpong/beerpong.py +++ b/fancy_gym/envs/mujoco/beerpong/beerpong.py @@ -77,7 +77,7 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle): self.dist_ground_cup = -1 # distance floor to cup if first floor contact self.observation_space = Box( - low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64 + low=-np.inf, high=np.inf, shape=(29,), dtype=np.float64 ) MujocoEnv.__init__( diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py index 3efcf3f..4fafd44 100644 --- a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py +++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py @@ -1,8 +1,8 @@ import os import numpy as np -from gym import utils, spaces -from gym.envs.mujoco import MujocoEnv +from gymnasium import utils, spaces +from gymnasium.envs.mujoco import MujocoEnv from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat @@ -51,7 +51,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): self._episode_energy = 0. self.observation_space = spaces.Box( - low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64 + low=-np.inf, high=np.inf, shape=(28,), dtype=np.float64 ) MujocoEnv.__init__(self, @@ -103,7 +103,11 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): 'is_success': True if episode_end and box_goal_pos_dist < 0.05 and box_goal_quat_dist < 0.5 else False, 'num_steps': self._steps } - return obs, reward, episode_end, infos + + terminated = episode_end and infos['is_success'] + truncated = episode_end and not infos['is_success'] + + return obs, reward, terminated, truncated, infos def reset_model(self): # rest box to initial position diff --git a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py index 60d387a..c8c15c3 100644 --- a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py +++ b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py @@ -40,11 +40,11 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML): if exclude_current_positions_from_observation: self.observation_space = spaces.Box( - low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64 + low=-np.inf, high=np.inf, shape=(12,), dtype=np.float64 ) else: self.observation_space = spaces.Box( - low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64 + low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64 ) xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) @@ -136,7 +136,9 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML): 'goal': self.box_x, } - return observation, reward, terminated, info + truncated = self.current_step >= self.max_episode_steps and not terminated + + return observation, reward, terminated, truncated, info def _get_obs(self): return np.append(super()._get_obs(), self.box_x) diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index ddf5022..55aa77c 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -1,8 +1,8 @@ import os import numpy as np -from gym import utils, spaces -from gym.envs.mujoco import MujocoEnv +from gymnasium import utils, spaces +from gymnasium.envs.mujoco import MujocoEnv from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound @@ -60,9 +60,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): self._artificial_force = 0. - self.observation_space = spaces.Box( - low=-np.inf, high=np.inf, shape=(9,), dtype=np.float64 - ) + if not hasattr(self, 'observation_space'): + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64 + ) MujocoEnv.__init__(self, model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), @@ -146,7 +147,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): land_dist_err = np.linalg.norm(self._ball_landing_pos[:-1] - self._goal_pos) \ if self._ball_landing_pos is not None else 10. - return self._get_obs(), reward, self._terminated, { + info = { "hit_ball": self._hit_ball, "ball_returned_success": self._ball_return_success, "land_dist_error": land_dist_err, @@ -154,6 +155,10 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): "num_steps": self._steps, } + terminated, truncated = self._terminated, False + + return self._get_obs(), reward, terminated, truncated, info + def _contact_checker(self, id_1, id_2): for coni in range(0, self.data.ncon): con = self.data.contact[coni] @@ -251,7 +256,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs): obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj penalty = self._get_traj_invalid_penalty(action, pos_traj) - return obs, penalty, True, { + return obs, penalty, True, False, { "hit_ball": [False], "ball_returned_success": [False], "land_dist_error": [10.], @@ -271,6 +276,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): class TableTennisWind(TableTennisEnv): def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4): + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(22,), dtype=np.float64 + ) super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True) def _get_obs(self):