Adding/fixing obs space definitions and metadata for various envs
This commit is contained in:
parent
f07b8a26ac
commit
ef64b0c21c
@ -5,6 +5,7 @@ import numpy as np
|
|||||||
from gymnasium import utils
|
from gymnasium import utils
|
||||||
from gymnasium.core import ObsType
|
from gymnasium.core import ObsType
|
||||||
from gymnasium.envs.mujoco import MujocoEnv
|
from gymnasium.envs.mujoco import MujocoEnv
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_BEERPONG = 300
|
MAX_EPISODE_STEPS_BEERPONG = 300
|
||||||
FIXED_RELEASE_STEP = 62 # empirically evaluated for frame_skip=2!
|
FIXED_RELEASE_STEP = 62 # empirically evaluated for frame_skip=2!
|
||||||
@ -31,6 +32,14 @@ CUP_COLLISION_OBJ = ["cup_geom_table3", "cup_geom_table4", "cup_geom_table5", "c
|
|||||||
|
|
||||||
|
|
||||||
class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
||||||
|
metadata = {
|
||||||
|
"render_modes": [
|
||||||
|
"human",
|
||||||
|
"rgb_array",
|
||||||
|
"depth_array",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
# Small Context -> Easier. Todo: Should we do different versions?
|
# Small Context -> Easier. Todo: Should we do different versions?
|
||||||
@ -66,6 +75,10 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
self.ball_in_cup = False
|
self.ball_in_cup = False
|
||||||
self.dist_ground_cup = -1 # distance floor to cup if first floor contact
|
self.dist_ground_cup = -1 # distance floor to cup if first floor contact
|
||||||
|
|
||||||
|
self.observation_space = Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
|
||||||
|
)
|
||||||
|
|
||||||
MujocoEnv.__init__(
|
MujocoEnv.__init__(
|
||||||
self,
|
self,
|
||||||
self.xml_path,
|
self.xml_path,
|
||||||
|
@ -27,6 +27,14 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
3. time-spatial-depend sparse reward
|
3. time-spatial-depend sparse reward
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"render_modes": [
|
||||||
|
"human",
|
||||||
|
"rgb_array",
|
||||||
|
"depth_array",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, frame_skip: int = 10):
|
def __init__(self, frame_skip: int = 10):
|
||||||
utils.EzPickle.__init__(**locals())
|
utils.EzPickle.__init__(**locals())
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
@ -40,9 +48,15 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
self._desired_rod_quat = desired_rod_quat
|
self._desired_rod_quat = desired_rod_quat
|
||||||
|
|
||||||
self._episode_energy = 0.
|
self._episode_energy = 0.
|
||||||
|
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
|
||||||
|
)
|
||||||
|
|
||||||
MujocoEnv.__init__(self,
|
MujocoEnv.__init__(self,
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"),
|
||||||
frame_skip=self.frame_skip)
|
frame_skip=self.frame_skip,
|
||||||
|
observation_space=self.observation_space)
|
||||||
self.action_space = spaces.Box(low=-1, high=1, shape=(7,))
|
self.action_space = spaces.Box(low=-1, high=1, shape=(7,))
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
@ -45,11 +45,11 @@ class HalfCheetahEnvCustomXML(HalfCheetahEnv):
|
|||||||
|
|
||||||
if exclude_current_positions_from_observation:
|
if exclude_current_positions_from_observation:
|
||||||
observation_space = Box(
|
observation_space = Box(
|
||||||
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
|
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
observation_space = Box(
|
observation_space = Box(
|
||||||
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
|
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
|
||||||
)
|
)
|
||||||
|
|
||||||
MujocoEnv.__init__(
|
MujocoEnv.__init__(
|
||||||
|
@ -67,20 +67,21 @@ class HopperEnvCustomXML(HopperEnv):
|
|||||||
exclude_current_positions_from_observation
|
exclude_current_positions_from_observation
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not hasattr(self, 'observation_space'):
|
||||||
if exclude_current_positions_from_observation:
|
if exclude_current_positions_from_observation:
|
||||||
observation_space = Box(
|
self.observation_space = Box(
|
||||||
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
|
low=-np.inf, high=np.inf, shape=(15,), dtype=np.float64
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
observation_space = Box(
|
self.observation_space = Box(
|
||||||
low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64
|
low=-np.inf, high=np.inf, shape=(16,), dtype=np.float64
|
||||||
)
|
)
|
||||||
|
|
||||||
MujocoEnv.__init__(
|
MujocoEnv.__init__(
|
||||||
self,
|
self,
|
||||||
xml_file,
|
xml_file,
|
||||||
4,
|
4,
|
||||||
observation_space=observation_space,
|
observation_space=self.observation_space,
|
||||||
default_camera_config=DEFAULT_CAMERA_CONFIG,
|
default_camera_config=DEFAULT_CAMERA_CONFIG,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
@ -4,6 +4,7 @@ from typing import Optional, Dict, Any, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.core import ObsType
|
from gymnasium.core import ObsType
|
||||||
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
|
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_HOPPERJUMPONBOX = 250
|
MAX_EPISODE_STEPS_HOPPERJUMPONBOX = 250
|
||||||
@ -36,6 +37,16 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
|
|||||||
self.hopper_on_box = False
|
self.hopper_on_box = False
|
||||||
self.context = context
|
self.context = context
|
||||||
self.box_x = 1
|
self.box_x = 1
|
||||||
|
|
||||||
|
if exclude_current_positions_from_observation:
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64
|
||||||
|
)
|
||||||
|
|
||||||
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
||||||
super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy,
|
super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy,
|
||||||
healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale,
|
healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale,
|
||||||
|
@ -4,6 +4,7 @@ from typing import Optional, Any, Dict, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.core import ObsType
|
from gymnasium.core import ObsType
|
||||||
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
|
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_HOPPERTHROW = 250
|
MAX_EPISODE_STEPS_HOPPERTHROW = 250
|
||||||
|
|
||||||
@ -37,6 +38,16 @@ class HopperThrowEnv(HopperEnvCustomXML):
|
|||||||
self.max_episode_steps = max_episode_steps
|
self.max_episode_steps = max_episode_steps
|
||||||
self.context = context
|
self.context = context
|
||||||
self.goal = 0
|
self.goal = 0
|
||||||
|
|
||||||
|
if not hasattr(self, 'observation_space'):
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(xml_file=xml_file,
|
super().__init__(xml_file=xml_file,
|
||||||
forward_reward_weight=forward_reward_weight,
|
forward_reward_weight=forward_reward_weight,
|
||||||
ctrl_cost_weight=ctrl_cost_weight,
|
ctrl_cost_weight=ctrl_cost_weight,
|
||||||
|
@ -4,6 +4,8 @@ from typing import Optional, Any, Dict, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
|
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
|
||||||
from gymnasium.core import ObsType
|
from gymnasium.core import ObsType
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250
|
MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250
|
||||||
|
|
||||||
@ -43,6 +45,16 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML):
|
|||||||
self.context = context
|
self.context = context
|
||||||
self.penalty = penalty
|
self.penalty = penalty
|
||||||
self.basket_x = 5
|
self.basket_x = 5
|
||||||
|
|
||||||
|
if exclude_current_positions_from_observation:
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
|
||||||
|
)
|
||||||
|
|
||||||
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
||||||
super().__init__(xml_file=xml_file,
|
super().__init__(xml_file=xml_file,
|
||||||
forward_reward_weight=forward_reward_weight,
|
forward_reward_weight=forward_reward_weight,
|
||||||
|
@ -23,6 +23,14 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
7 DoF table tennis environment
|
7 DoF table tennis environment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"render_modes": [
|
||||||
|
"human",
|
||||||
|
"rgb_array",
|
||||||
|
"depth_array",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
|
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
|
||||||
goal_switching_step: int = None,
|
goal_switching_step: int = None,
|
||||||
enable_artificial_wind: bool = False):
|
enable_artificial_wind: bool = False):
|
||||||
@ -51,9 +59,14 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
self._artificial_force = 0.
|
self._artificial_force = 0.
|
||||||
|
|
||||||
|
self.observation_space = spaces.Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(9,), dtype=np.float64
|
||||||
|
)
|
||||||
|
|
||||||
MujocoEnv.__init__(self,
|
MujocoEnv.__init__(self,
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
||||||
frame_skip=frame_skip,)
|
frame_skip=frame_skip,
|
||||||
|
observation_space=self.observation_space)
|
||||||
|
|
||||||
if ctxt_dim == 2:
|
if ctxt_dim == 2:
|
||||||
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
||||||
|
@ -61,11 +61,11 @@ class Walker2dEnvCustomXML(Walker2dEnv):
|
|||||||
|
|
||||||
if exclude_current_positions_from_observation:
|
if exclude_current_positions_from_observation:
|
||||||
observation_space = Box(
|
observation_space = Box(
|
||||||
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
|
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
observation_space = Box(
|
observation_space = Box(
|
||||||
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
|
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
|
||||||
)
|
)
|
||||||
|
|
||||||
self.observation_space = observation_space
|
self.observation_space = observation_space
|
||||||
|
Loading…
Reference in New Issue
Block a user