Adding/fixing obs space definitions and metadata for various envs

This commit is contained in:
Dominik Moritz Roth 2023-06-11 11:08:46 +02:00
parent f07b8a26ac
commit ef64b0c21c
9 changed files with 92 additions and 17 deletions

View File

@ -5,6 +5,7 @@ import numpy as np
from gymnasium import utils from gymnasium import utils
from gymnasium.core import ObsType from gymnasium.core import ObsType
from gymnasium.envs.mujoco import MujocoEnv from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box
MAX_EPISODE_STEPS_BEERPONG = 300 MAX_EPISODE_STEPS_BEERPONG = 300
FIXED_RELEASE_STEP = 62 # empirically evaluated for frame_skip=2! FIXED_RELEASE_STEP = 62 # empirically evaluated for frame_skip=2!
@ -31,6 +32,14 @@ CUP_COLLISION_OBJ = ["cup_geom_table3", "cup_geom_table4", "cup_geom_table5", "c
class BeerPongEnv(MujocoEnv, utils.EzPickle): class BeerPongEnv(MujocoEnv, utils.EzPickle):
metadata = {
"render_modes": [
"human",
"rgb_array",
"depth_array",
],
}
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._steps = 0 self._steps = 0
# Small Context -> Easier. Todo: Should we do different versions? # Small Context -> Easier. Todo: Should we do different versions?
@ -66,6 +75,10 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
self.ball_in_cup = False self.ball_in_cup = False
self.dist_ground_cup = -1 # distance floor to cup if first floor contact self.dist_ground_cup = -1 # distance floor to cup if first floor contact
self.observation_space = Box(
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
)
MujocoEnv.__init__( MujocoEnv.__init__(
self, self,
self.xml_path, self.xml_path,

View File

@ -27,6 +27,14 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
3. time-spatial-depend sparse reward 3. time-spatial-depend sparse reward
""" """
metadata = {
"render_modes": [
"human",
"rgb_array",
"depth_array",
],
}
def __init__(self, frame_skip: int = 10): def __init__(self, frame_skip: int = 10):
utils.EzPickle.__init__(**locals()) utils.EzPickle.__init__(**locals())
self._steps = 0 self._steps = 0
@ -40,9 +48,15 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
self._desired_rod_quat = desired_rod_quat self._desired_rod_quat = desired_rod_quat
self._episode_energy = 0. self._episode_energy = 0.
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
)
MujocoEnv.__init__(self, MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"), model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"),
frame_skip=self.frame_skip) frame_skip=self.frame_skip,
observation_space=self.observation_space)
self.action_space = spaces.Box(low=-1, high=1, shape=(7,)) self.action_space = spaces.Box(low=-1, high=1, shape=(7,))
def step(self, action): def step(self, action):

View File

@ -45,11 +45,11 @@ class HalfCheetahEnvCustomXML(HalfCheetahEnv):
if exclude_current_positions_from_observation: if exclude_current_positions_from_observation:
observation_space = Box( observation_space = Box(
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64 low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
) )
else: else:
observation_space = Box( observation_space = Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64 low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
) )
MujocoEnv.__init__( MujocoEnv.__init__(

View File

@ -67,20 +67,21 @@ class HopperEnvCustomXML(HopperEnv):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
if not hasattr(self, 'observation_space'):
if exclude_current_positions_from_observation: if exclude_current_positions_from_observation:
observation_space = Box( self.observation_space = Box(
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64 low=-np.inf, high=np.inf, shape=(15,), dtype=np.float64
) )
else: else:
observation_space = Box( self.observation_space = Box(
low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64 low=-np.inf, high=np.inf, shape=(16,), dtype=np.float64
) )
MujocoEnv.__init__( MujocoEnv.__init__(
self, self,
xml_file, xml_file,
4, 4,
observation_space=observation_space, observation_space=self.observation_space,
default_camera_config=DEFAULT_CAMERA_CONFIG, default_camera_config=DEFAULT_CAMERA_CONFIG,
**kwargs, **kwargs,
) )

View File

@ -4,6 +4,7 @@ from typing import Optional, Dict, Any, Tuple
import numpy as np import numpy as np
from gymnasium.core import ObsType from gymnasium.core import ObsType
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
from gymnasium import spaces
MAX_EPISODE_STEPS_HOPPERJUMPONBOX = 250 MAX_EPISODE_STEPS_HOPPERJUMPONBOX = 250
@ -36,6 +37,16 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
self.hopper_on_box = False self.hopper_on_box = False
self.context = context self.context = context
self.box_x = 1 self.box_x = 1
if exclude_current_positions_from_observation:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
)
else:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64
)
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy, super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy,
healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale, healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale,

View File

@ -4,6 +4,7 @@ from typing import Optional, Any, Dict, Tuple
import numpy as np import numpy as np
from gymnasium.core import ObsType from gymnasium.core import ObsType
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
from gymnasium import spaces
MAX_EPISODE_STEPS_HOPPERTHROW = 250 MAX_EPISODE_STEPS_HOPPERTHROW = 250
@ -37,6 +38,16 @@ class HopperThrowEnv(HopperEnvCustomXML):
self.max_episode_steps = max_episode_steps self.max_episode_steps = max_episode_steps
self.context = context self.context = context
self.goal = 0 self.goal = 0
if not hasattr(self, 'observation_space'):
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
else:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
)
super().__init__(xml_file=xml_file, super().__init__(xml_file=xml_file,
forward_reward_weight=forward_reward_weight, forward_reward_weight=forward_reward_weight,
ctrl_cost_weight=ctrl_cost_weight, ctrl_cost_weight=ctrl_cost_weight,

View File

@ -4,6 +4,8 @@ from typing import Optional, Any, Dict, Tuple
import numpy as np import numpy as np
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
from gymnasium.core import ObsType from gymnasium.core import ObsType
from gymnasium import spaces
MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250 MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250
@ -43,6 +45,16 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML):
self.context = context self.context = context
self.penalty = penalty self.penalty = penalty
self.basket_x = 5 self.basket_x = 5
if exclude_current_positions_from_observation:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
else:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
)
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
super().__init__(xml_file=xml_file, super().__init__(xml_file=xml_file,
forward_reward_weight=forward_reward_weight, forward_reward_weight=forward_reward_weight,

View File

@ -23,6 +23,14 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
7 DoF table tennis environment 7 DoF table tennis environment
""" """
metadata = {
"render_modes": [
"human",
"rgb_array",
"depth_array",
],
}
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4, def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
goal_switching_step: int = None, goal_switching_step: int = None,
enable_artificial_wind: bool = False): enable_artificial_wind: bool = False):
@ -51,9 +59,14 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self._artificial_force = 0. self._artificial_force = 0.
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(9,), dtype=np.float64
)
MujocoEnv.__init__(self, MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
frame_skip=frame_skip,) frame_skip=frame_skip,
observation_space=self.observation_space)
if ctxt_dim == 2: if ctxt_dim == 2:
self.context_bounds = CONTEXT_BOUNDS_2DIMS self.context_bounds = CONTEXT_BOUNDS_2DIMS

View File

@ -61,11 +61,11 @@ class Walker2dEnvCustomXML(Walker2dEnv):
if exclude_current_positions_from_observation: if exclude_current_positions_from_observation:
observation_space = Box( observation_space = Box(
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64 low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
) )
else: else:
observation_space = Box( observation_space = Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64 low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
) )
self.observation_space = observation_space self.observation_space = observation_space