From ef64b0c21c4b0b49070c00298c500e8e3224f62e Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sun, 11 Jun 2023 11:08:46 +0200 Subject: [PATCH] Adding/fixing obs space definitions and metadata for various envs --- fancy_gym/envs/mujoco/beerpong/beerpong.py | 17 +++++++++++++++-- .../mujoco/box_pushing/box_pushing_env.py | 16 +++++++++++++++- .../half_cheetah_jump/half_cheetah_jump.py | 4 ++-- .../envs/mujoco/hopper_jump/hopper_jump.py | 19 ++++++++++--------- .../mujoco/hopper_jump/hopper_jump_on_box.py | 11 +++++++++++ .../envs/mujoco/hopper_throw/hopper_throw.py | 11 +++++++++++ .../hopper_throw/hopper_throw_in_basket.py | 12 ++++++++++++ .../mujoco/table_tennis/table_tennis_env.py | 15 ++++++++++++++- .../mujoco/walker_2d_jump/walker_2d_jump.py | 4 ++-- 9 files changed, 92 insertions(+), 17 deletions(-) diff --git a/fancy_gym/envs/mujoco/beerpong/beerpong.py b/fancy_gym/envs/mujoco/beerpong/beerpong.py index 6a37e66..1f35bce 100644 --- a/fancy_gym/envs/mujoco/beerpong/beerpong.py +++ b/fancy_gym/envs/mujoco/beerpong/beerpong.py @@ -5,6 +5,7 @@ import numpy as np from gymnasium import utils from gymnasium.core import ObsType from gymnasium.envs.mujoco import MujocoEnv +from gymnasium.spaces import Box MAX_EPISODE_STEPS_BEERPONG = 300 FIXED_RELEASE_STEP = 62 # empirically evaluated for frame_skip=2! @@ -31,6 +32,14 @@ CUP_COLLISION_OBJ = ["cup_geom_table3", "cup_geom_table4", "cup_geom_table5", "c class BeerPongEnv(MujocoEnv, utils.EzPickle): + metadata = { + "render_modes": [ + "human", + "rgb_array", + "depth_array", + ], + } + def __init__(self, **kwargs): self._steps = 0 # Small Context -> Easier. Todo: Should we do different versions? @@ -66,6 +75,10 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle): self.ball_in_cup = False self.dist_ground_cup = -1 # distance floor to cup if first floor contact + self.observation_space = Box( + low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64 + ) + MujocoEnv.__init__( self, self.xml_path, @@ -208,13 +221,13 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle): min_dist_coeff, final_dist_coeff, ground_contact_dist_coeff, rew_offset = 0, 1, 0, 0 action_cost = 1e-4 * np.mean(action_cost) reward = rew_offset - min_dist_coeff * min_dist ** 2 - final_dist_coeff * final_dist ** 2 - \ - action_cost - ground_contact_dist_coeff * self.dist_ground_cup ** 2 + action_cost - ground_contact_dist_coeff * self.dist_ground_cup ** 2 # release step punishment min_time_bound = 0.1 max_time_bound = 1.0 release_time = self.release_step * self.dt release_time_rew = int(release_time < min_time_bound) * (-30 - 10 * (release_time - min_time_bound) ** 2) + \ - int(release_time > max_time_bound) * (-30 - 10 * (release_time - max_time_bound) ** 2) + int(release_time > max_time_bound) * (-30 - 10 * (release_time - max_time_bound) ** 2) reward += release_time_rew success = self.ball_in_cup else: diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py index 2408404..65db553 100644 --- a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py +++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py @@ -27,6 +27,14 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): 3. time-spatial-depend sparse reward """ + metadata = { + "render_modes": [ + "human", + "rgb_array", + "depth_array", + ], + } + def __init__(self, frame_skip: int = 10): utils.EzPickle.__init__(**locals()) self._steps = 0 @@ -40,9 +48,15 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): self._desired_rod_quat = desired_rod_quat self._episode_energy = 0. + + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64 + ) + MujocoEnv.__init__(self, model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"), - frame_skip=self.frame_skip) + frame_skip=self.frame_skip, + observation_space=self.observation_space) self.action_space = spaces.Box(low=-1, high=1, shape=(7,)) def step(self, action): diff --git a/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py b/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py index f4bc677..4ef2757 100644 --- a/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py +++ b/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py @@ -45,11 +45,11 @@ class HalfCheetahEnvCustomXML(HalfCheetahEnv): if exclude_current_positions_from_observation: observation_space = Box( - low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64 + low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64 ) else: observation_space = Box( - low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64 + low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64 ) MujocoEnv.__init__( diff --git a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py index 0da71db..f7936c7 100644 --- a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py +++ b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py @@ -67,20 +67,21 @@ class HopperEnvCustomXML(HopperEnv): exclude_current_positions_from_observation ) - if exclude_current_positions_from_observation: - observation_space = Box( - low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64 - ) - else: - observation_space = Box( - low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64 - ) + if not hasattr(self, 'observation_space'): + if exclude_current_positions_from_observation: + self.observation_space = Box( + low=-np.inf, high=np.inf, shape=(15,), dtype=np.float64 + ) + else: + self.observation_space = Box( + low=-np.inf, high=np.inf, shape=(16,), dtype=np.float64 + ) MujocoEnv.__init__( self, xml_file, 4, - observation_space=observation_space, + observation_space=self.observation_space, default_camera_config=DEFAULT_CAMERA_CONFIG, **kwargs, ) diff --git a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py index 7dab661..60d387a 100644 --- a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py +++ b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py @@ -4,6 +4,7 @@ from typing import Optional, Dict, Any, Tuple import numpy as np from gymnasium.core import ObsType from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML +from gymnasium import spaces MAX_EPISODE_STEPS_HOPPERJUMPONBOX = 250 @@ -36,6 +37,16 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML): self.hopper_on_box = False self.context = context self.box_x = 1 + + if exclude_current_positions_from_observation: + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64 + ) + else: + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64 + ) + xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy, healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale, diff --git a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py index bb38c88..2dd82b2 100644 --- a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py +++ b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py @@ -4,6 +4,7 @@ from typing import Optional, Any, Dict, Tuple import numpy as np from gymnasium.core import ObsType from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML +from gymnasium import spaces MAX_EPISODE_STEPS_HOPPERTHROW = 250 @@ -37,6 +38,16 @@ class HopperThrowEnv(HopperEnvCustomXML): self.max_episode_steps = max_episode_steps self.context = context self.goal = 0 + + if not hasattr(self, 'observation_space'): + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64 + ) + else: + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64 + ) + super().__init__(xml_file=xml_file, forward_reward_weight=forward_reward_weight, ctrl_cost_weight=ctrl_cost_weight, diff --git a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py index 6d49dcb..be6b81a 100644 --- a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py +++ b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py @@ -4,6 +4,8 @@ from typing import Optional, Any, Dict, Tuple import numpy as np from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML from gymnasium.core import ObsType +from gymnasium import spaces + MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250 @@ -43,6 +45,16 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML): self.context = context self.penalty = penalty self.basket_x = 5 + + if exclude_current_positions_from_observation: + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64 + ) + else: + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64 + ) + xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) super().__init__(xml_file=xml_file, forward_reward_weight=forward_reward_weight, diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index 872aa75..a5d67c0 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -23,6 +23,14 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): 7 DoF table tennis environment """ + metadata = { + "render_modes": [ + "human", + "rgb_array", + "depth_array", + ], + } + def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4, goal_switching_step: int = None, enable_artificial_wind: bool = False): @@ -51,9 +59,14 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): self._artificial_force = 0. + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(9,), dtype=np.float64 + ) + MujocoEnv.__init__(self, model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), - frame_skip=frame_skip,) + frame_skip=frame_skip, + observation_space=self.observation_space) if ctxt_dim == 2: self.context_bounds = CONTEXT_BOUNDS_2DIMS diff --git a/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py b/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py index fe8d0b2..127719c 100644 --- a/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py +++ b/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py @@ -61,11 +61,11 @@ class Walker2dEnvCustomXML(Walker2dEnv): if exclude_current_positions_from_observation: observation_space = Box( - low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64 + low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64 ) else: observation_space = Box( - low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64 + low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64 ) self.observation_space = observation_space