Adding/fixing obs space definitions and metadata for various envs

This commit is contained in:
Dominik Moritz Roth 2023-06-11 11:08:46 +02:00
parent f07b8a26ac
commit ef64b0c21c
9 changed files with 92 additions and 17 deletions

View File

@ -5,6 +5,7 @@ import numpy as np
from gymnasium import utils
from gymnasium.core import ObsType
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box
MAX_EPISODE_STEPS_BEERPONG = 300
FIXED_RELEASE_STEP = 62 # empirically evaluated for frame_skip=2!
@ -31,6 +32,14 @@ CUP_COLLISION_OBJ = ["cup_geom_table3", "cup_geom_table4", "cup_geom_table5", "c
class BeerPongEnv(MujocoEnv, utils.EzPickle):
metadata = {
"render_modes": [
"human",
"rgb_array",
"depth_array",
],
}
def __init__(self, **kwargs):
self._steps = 0
# Small Context -> Easier. Todo: Should we do different versions?
@ -66,6 +75,10 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
self.ball_in_cup = False
self.dist_ground_cup = -1 # distance floor to cup if first floor contact
self.observation_space = Box(
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
)
MujocoEnv.__init__(
self,
self.xml_path,
@ -208,13 +221,13 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
min_dist_coeff, final_dist_coeff, ground_contact_dist_coeff, rew_offset = 0, 1, 0, 0
action_cost = 1e-4 * np.mean(action_cost)
reward = rew_offset - min_dist_coeff * min_dist ** 2 - final_dist_coeff * final_dist ** 2 - \
action_cost - ground_contact_dist_coeff * self.dist_ground_cup ** 2
action_cost - ground_contact_dist_coeff * self.dist_ground_cup ** 2
# release step punishment
min_time_bound = 0.1
max_time_bound = 1.0
release_time = self.release_step * self.dt
release_time_rew = int(release_time < min_time_bound) * (-30 - 10 * (release_time - min_time_bound) ** 2) + \
int(release_time > max_time_bound) * (-30 - 10 * (release_time - max_time_bound) ** 2)
int(release_time > max_time_bound) * (-30 - 10 * (release_time - max_time_bound) ** 2)
reward += release_time_rew
success = self.ball_in_cup
else:

View File

@ -27,6 +27,14 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
3. time-spatial-depend sparse reward
"""
metadata = {
"render_modes": [
"human",
"rgb_array",
"depth_array",
],
}
def __init__(self, frame_skip: int = 10):
utils.EzPickle.__init__(**locals())
self._steps = 0
@ -40,9 +48,15 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
self._desired_rod_quat = desired_rod_quat
self._episode_energy = 0.
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
)
MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"),
frame_skip=self.frame_skip)
frame_skip=self.frame_skip,
observation_space=self.observation_space)
self.action_space = spaces.Box(low=-1, high=1, shape=(7,))
def step(self, action):

View File

@ -45,11 +45,11 @@ class HalfCheetahEnvCustomXML(HalfCheetahEnv):
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
)
MujocoEnv.__init__(

View File

@ -67,20 +67,21 @@ class HopperEnvCustomXML(HopperEnv):
exclude_current_positions_from_observation
)
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64
)
if not hasattr(self, 'observation_space'):
if exclude_current_positions_from_observation:
self.observation_space = Box(
low=-np.inf, high=np.inf, shape=(15,), dtype=np.float64
)
else:
self.observation_space = Box(
low=-np.inf, high=np.inf, shape=(16,), dtype=np.float64
)
MujocoEnv.__init__(
self,
xml_file,
4,
observation_space=observation_space,
observation_space=self.observation_space,
default_camera_config=DEFAULT_CAMERA_CONFIG,
**kwargs,
)

View File

@ -4,6 +4,7 @@ from typing import Optional, Dict, Any, Tuple
import numpy as np
from gymnasium.core import ObsType
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
from gymnasium import spaces
MAX_EPISODE_STEPS_HOPPERJUMPONBOX = 250
@ -36,6 +37,16 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
self.hopper_on_box = False
self.context = context
self.box_x = 1
if exclude_current_positions_from_observation:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
)
else:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64
)
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy,
healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale,

View File

@ -4,6 +4,7 @@ from typing import Optional, Any, Dict, Tuple
import numpy as np
from gymnasium.core import ObsType
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
from gymnasium import spaces
MAX_EPISODE_STEPS_HOPPERTHROW = 250
@ -37,6 +38,16 @@ class HopperThrowEnv(HopperEnvCustomXML):
self.max_episode_steps = max_episode_steps
self.context = context
self.goal = 0
if not hasattr(self, 'observation_space'):
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
else:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
)
super().__init__(xml_file=xml_file,
forward_reward_weight=forward_reward_weight,
ctrl_cost_weight=ctrl_cost_weight,

View File

@ -4,6 +4,8 @@ from typing import Optional, Any, Dict, Tuple
import numpy as np
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
from gymnasium.core import ObsType
from gymnasium import spaces
MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250
@ -43,6 +45,16 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML):
self.context = context
self.penalty = penalty
self.basket_x = 5
if exclude_current_positions_from_observation:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
else:
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
)
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
super().__init__(xml_file=xml_file,
forward_reward_weight=forward_reward_weight,

View File

@ -23,6 +23,14 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
7 DoF table tennis environment
"""
metadata = {
"render_modes": [
"human",
"rgb_array",
"depth_array",
],
}
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
goal_switching_step: int = None,
enable_artificial_wind: bool = False):
@ -51,9 +59,14 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
self._artificial_force = 0.
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(9,), dtype=np.float64
)
MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
frame_skip=frame_skip,)
frame_skip=frame_skip,
observation_space=self.observation_space)
if ctxt_dim == 2:
self.context_bounds = CONTEXT_BOUNDS_2DIMS

View File

@ -61,11 +61,11 @@ class Walker2dEnvCustomXML(Walker2dEnv):
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
)
self.observation_space = observation_space