Add/fix observation space definitions and render-mode metadata for various envs
This commit is contained in:
		
							parent
							
								
									f07b8a26ac
								
							
						
					
					
						commit
						ef64b0c21c
					
				@ -5,6 +5,7 @@ import numpy as np
 | 
			
		||||
from gymnasium import utils
 | 
			
		||||
from gymnasium.core import ObsType
 | 
			
		||||
from gymnasium.envs.mujoco import MujocoEnv
 | 
			
		||||
from gymnasium.spaces import Box
 | 
			
		||||
 | 
			
		||||
MAX_EPISODE_STEPS_BEERPONG = 300
 | 
			
		||||
FIXED_RELEASE_STEP = 62  # empirically evaluated for frame_skip=2!
 | 
			
		||||
@ -31,6 +32,14 @@ CUP_COLLISION_OBJ = ["cup_geom_table3", "cup_geom_table4", "cup_geom_table5", "c
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BeerPongEnv(MujocoEnv, utils.EzPickle):
 | 
			
		||||
    metadata = {
 | 
			
		||||
        "render_modes": [
 | 
			
		||||
            "human",
 | 
			
		||||
            "rgb_array",
 | 
			
		||||
            "depth_array",
 | 
			
		||||
        ],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def __init__(self, **kwargs):
 | 
			
		||||
        self._steps = 0
 | 
			
		||||
        # Small Context -> Easier. Todo: Should we do different versions?
 | 
			
		||||
@ -66,6 +75,10 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
 | 
			
		||||
        self.ball_in_cup = False
 | 
			
		||||
        self.dist_ground_cup = -1  # distance floor to cup if first floor contact
 | 
			
		||||
 | 
			
		||||
        self.observation_space = Box(
 | 
			
		||||
            low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        MujocoEnv.__init__(
 | 
			
		||||
            self,
 | 
			
		||||
            self.xml_path,
 | 
			
		||||
@ -208,13 +221,13 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
 | 
			
		||||
                    min_dist_coeff, final_dist_coeff, ground_contact_dist_coeff, rew_offset = 0, 1, 0, 0
 | 
			
		||||
            action_cost = 1e-4 * np.mean(action_cost)
 | 
			
		||||
            reward = rew_offset - min_dist_coeff * min_dist ** 2 - final_dist_coeff * final_dist ** 2 - \
 | 
			
		||||
                     action_cost - ground_contact_dist_coeff * self.dist_ground_cup ** 2
 | 
			
		||||
                action_cost - ground_contact_dist_coeff * self.dist_ground_cup ** 2
 | 
			
		||||
            # release step punishment
 | 
			
		||||
            min_time_bound = 0.1
 | 
			
		||||
            max_time_bound = 1.0
 | 
			
		||||
            release_time = self.release_step * self.dt
 | 
			
		||||
            release_time_rew = int(release_time < min_time_bound) * (-30 - 10 * (release_time - min_time_bound) ** 2) + \
 | 
			
		||||
                               int(release_time > max_time_bound) * (-30 - 10 * (release_time - max_time_bound) ** 2)
 | 
			
		||||
                int(release_time > max_time_bound) * (-30 - 10 * (release_time - max_time_bound) ** 2)
 | 
			
		||||
            reward += release_time_rew
 | 
			
		||||
            success = self.ball_in_cup
 | 
			
		||||
        else:
 | 
			
		||||
 | 
			
		||||
@ -27,6 +27,14 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
 | 
			
		||||
    3. time-spatial-dependent sparse reward
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    metadata = {
 | 
			
		||||
        "render_modes": [
 | 
			
		||||
            "human",
 | 
			
		||||
            "rgb_array",
 | 
			
		||||
            "depth_array",
 | 
			
		||||
        ],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def __init__(self, frame_skip: int = 10):
 | 
			
		||||
        utils.EzPickle.__init__(**locals())
 | 
			
		||||
        self._steps = 0
 | 
			
		||||
@ -40,9 +48,15 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
 | 
			
		||||
        self._desired_rod_quat = desired_rod_quat
 | 
			
		||||
 | 
			
		||||
        self._episode_energy = 0.
 | 
			
		||||
 | 
			
		||||
        self.observation_space = spaces.Box(
 | 
			
		||||
            low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        MujocoEnv.__init__(self,
 | 
			
		||||
                           model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"),
 | 
			
		||||
                           frame_skip=self.frame_skip)
 | 
			
		||||
                           frame_skip=self.frame_skip,
 | 
			
		||||
                           observation_space=self.observation_space)
 | 
			
		||||
        self.action_space = spaces.Box(low=-1, high=1, shape=(7,))
 | 
			
		||||
 | 
			
		||||
    def step(self, action):
 | 
			
		||||
 | 
			
		||||
@ -45,11 +45,11 @@ class HalfCheetahEnvCustomXML(HalfCheetahEnv):
 | 
			
		||||
 | 
			
		||||
        if exclude_current_positions_from_observation:
 | 
			
		||||
            observation_space = Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            observation_space = Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        MujocoEnv.__init__(
 | 
			
		||||
 | 
			
		||||
@ -67,20 +67,21 @@ class HopperEnvCustomXML(HopperEnv):
 | 
			
		||||
            exclude_current_positions_from_observation
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if exclude_current_positions_from_observation:
 | 
			
		||||
            observation_space = Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            observation_space = Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
        if not hasattr(self, 'observation_space'):
 | 
			
		||||
            if exclude_current_positions_from_observation:
 | 
			
		||||
                self.observation_space = Box(
 | 
			
		||||
                    low=-np.inf, high=np.inf, shape=(15,), dtype=np.float64
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                self.observation_space = Box(
 | 
			
		||||
                    low=-np.inf, high=np.inf, shape=(16,), dtype=np.float64
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        MujocoEnv.__init__(
 | 
			
		||||
            self,
 | 
			
		||||
            xml_file,
 | 
			
		||||
            4,
 | 
			
		||||
            observation_space=observation_space,
 | 
			
		||||
            observation_space=self.observation_space,
 | 
			
		||||
            default_camera_config=DEFAULT_CAMERA_CONFIG,
 | 
			
		||||
            **kwargs,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -4,6 +4,7 @@ from typing import Optional, Dict, Any, Tuple
 | 
			
		||||
import numpy as np
 | 
			
		||||
from gymnasium.core import ObsType
 | 
			
		||||
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
 | 
			
		||||
from gymnasium import spaces
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
MAX_EPISODE_STEPS_HOPPERJUMPONBOX = 250
 | 
			
		||||
@ -36,6 +37,16 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
 | 
			
		||||
        self.hopper_on_box = False
 | 
			
		||||
        self.context = context
 | 
			
		||||
        self.box_x = 1
 | 
			
		||||
 | 
			
		||||
        if exclude_current_positions_from_observation:
 | 
			
		||||
            self.observation_space = spaces.Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(13,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            self.observation_space = spaces.Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(14,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
 | 
			
		||||
        super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy,
 | 
			
		||||
                         healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale,
 | 
			
		||||
 | 
			
		||||
@ -4,6 +4,7 @@ from typing import Optional, Any, Dict, Tuple
 | 
			
		||||
import numpy as np
 | 
			
		||||
from gymnasium.core import ObsType
 | 
			
		||||
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
 | 
			
		||||
from gymnasium import spaces
 | 
			
		||||
 | 
			
		||||
MAX_EPISODE_STEPS_HOPPERTHROW = 250
 | 
			
		||||
 | 
			
		||||
@ -37,6 +38,16 @@ class HopperThrowEnv(HopperEnvCustomXML):
 | 
			
		||||
        self.max_episode_steps = max_episode_steps
 | 
			
		||||
        self.context = context
 | 
			
		||||
        self.goal = 0
 | 
			
		||||
 | 
			
		||||
        if not hasattr(self, 'observation_space'):
 | 
			
		||||
            self.observation_space = spaces.Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            self.observation_space = spaces.Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        super().__init__(xml_file=xml_file,
 | 
			
		||||
                         forward_reward_weight=forward_reward_weight,
 | 
			
		||||
                         ctrl_cost_weight=ctrl_cost_weight,
 | 
			
		||||
 | 
			
		||||
@ -4,6 +4,8 @@ from typing import Optional, Any, Dict, Tuple
 | 
			
		||||
import numpy as np
 | 
			
		||||
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
 | 
			
		||||
from gymnasium.core import ObsType
 | 
			
		||||
from gymnasium import spaces
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250
 | 
			
		||||
 | 
			
		||||
@ -43,6 +45,16 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML):
 | 
			
		||||
        self.context = context
 | 
			
		||||
        self.penalty = penalty
 | 
			
		||||
        self.basket_x = 5
 | 
			
		||||
 | 
			
		||||
        if exclude_current_positions_from_observation:
 | 
			
		||||
            self.observation_space = spaces.Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            self.observation_space = spaces.Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
 | 
			
		||||
        super().__init__(xml_file=xml_file,
 | 
			
		||||
                         forward_reward_weight=forward_reward_weight,
 | 
			
		||||
 | 
			
		||||
@ -23,6 +23,14 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
 | 
			
		||||
    7 DoF table tennis environment
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    metadata = {
 | 
			
		||||
        "render_modes": [
 | 
			
		||||
            "human",
 | 
			
		||||
            "rgb_array",
 | 
			
		||||
            "depth_array",
 | 
			
		||||
        ],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
 | 
			
		||||
                 goal_switching_step: int = None,
 | 
			
		||||
                 enable_artificial_wind: bool = False):
 | 
			
		||||
@ -51,9 +59,14 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
 | 
			
		||||
 | 
			
		||||
        self._artificial_force = 0.
 | 
			
		||||
 | 
			
		||||
        self.observation_space = spaces.Box(
 | 
			
		||||
            low=-np.inf, high=np.inf, shape=(9,), dtype=np.float64
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        MujocoEnv.__init__(self,
 | 
			
		||||
                           model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
 | 
			
		||||
                           frame_skip=frame_skip,)
 | 
			
		||||
                           frame_skip=frame_skip,
 | 
			
		||||
                           observation_space=self.observation_space)
 | 
			
		||||
 | 
			
		||||
        if ctxt_dim == 2:
 | 
			
		||||
            self.context_bounds = CONTEXT_BOUNDS_2DIMS
 | 
			
		||||
 | 
			
		||||
@ -61,11 +61,11 @@ class Walker2dEnvCustomXML(Walker2dEnv):
 | 
			
		||||
 | 
			
		||||
        if exclude_current_positions_from_observation:
 | 
			
		||||
            observation_space = Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            observation_space = Box(
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
 | 
			
		||||
                low=-np.inf, high=np.inf, shape=(19,), dtype=np.float64
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.observation_space = observation_space
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user