Adapted Mujoco Envs to new gymnasium spec
Gymnasium Mujoco Envs no longer allow overriding the used xml_file We therefore implement intermediate classes, that reimplement this feature.
This commit is contained in:
parent
1c002a235b
commit
dabfc7cafe
@ -3,12 +3,66 @@ from typing import Tuple, Union, Optional, Any, Dict
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.core import ObsType
|
from gymnasium.core import ObsType
|
||||||
from gymnasium.envs.mujoco.half_cheetah_v4 import HalfCheetahEnv
|
from gymnasium.envs.mujoco.half_cheetah_v4 import HalfCheetahEnv, DEFAULT_CAMERA_CONFIG
|
||||||
|
|
||||||
|
from gymnasium import utils
|
||||||
|
from gymnasium.envs.mujoco import MujocoEnv
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_HALFCHEETAHJUMP = 100
|
MAX_EPISODE_STEPS_HALFCHEETAHJUMP = 100
|
||||||
|
|
||||||
|
|
||||||
class HalfCheetahJumpEnv(HalfCheetahEnv):
|
class HalfCheetahEnvCustomXML(HalfCheetahEnv):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
xml_file,
|
||||||
|
forward_reward_weight=1.0,
|
||||||
|
ctrl_cost_weight=0.1,
|
||||||
|
reset_noise_scale=0.1,
|
||||||
|
exclude_current_positions_from_observation=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
utils.EzPickle.__init__(
|
||||||
|
self,
|
||||||
|
xml_file,
|
||||||
|
forward_reward_weight,
|
||||||
|
ctrl_cost_weight,
|
||||||
|
reset_noise_scale,
|
||||||
|
exclude_current_positions_from_observation,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._forward_reward_weight = forward_reward_weight
|
||||||
|
|
||||||
|
self._ctrl_cost_weight = ctrl_cost_weight
|
||||||
|
|
||||||
|
self._reset_noise_scale = reset_noise_scale
|
||||||
|
|
||||||
|
self._exclude_current_positions_from_observation = (
|
||||||
|
exclude_current_positions_from_observation
|
||||||
|
)
|
||||||
|
|
||||||
|
if exclude_current_positions_from_observation:
|
||||||
|
observation_space = Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
observation_space = Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
|
||||||
|
)
|
||||||
|
|
||||||
|
MujocoEnv.__init__(
|
||||||
|
self,
|
||||||
|
xml_file,
|
||||||
|
5,
|
||||||
|
observation_space=observation_space,
|
||||||
|
default_camera_config=DEFAULT_CAMERA_CONFIG,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML):
|
||||||
"""
|
"""
|
||||||
_ctrl_cost_weight 0.1 -> 0.0
|
_ctrl_cost_weight 0.1 -> 0.0
|
||||||
"""
|
"""
|
||||||
@ -41,7 +95,7 @@ class HalfCheetahJumpEnv(HalfCheetahEnv):
|
|||||||
height_after = self.get_body_com("torso")[2]
|
height_after = self.get_body_com("torso")[2]
|
||||||
self.max_height = max(height_after, self.max_height)
|
self.max_height = max(height_after, self.max_height)
|
||||||
|
|
||||||
## Didnt use fell_over, because base env also has no done condition - Paul and Marc
|
# Didnt use fell_over, because base env also has no done condition - Paul and Marc
|
||||||
# fell_over = abs(self.sim.data.qpos[2]) > 2.5 # how to figure out if the cheetah fell over? -> 2.5 oke?
|
# fell_over = abs(self.sim.data.qpos[2]) > 2.5 # how to figure out if the cheetah fell over? -> 2.5 oke?
|
||||||
# TODO: Should a fall over be checked here?
|
# TODO: Should a fall over be checked here?
|
||||||
terminated = False
|
terminated = False
|
||||||
|
@ -1,12 +1,92 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.envs.mujoco.hopper_v4 import HopperEnv
|
from gymnasium.envs.mujoco.hopper_v4 import HopperEnv, DEFAULT_CAMERA_CONFIG
|
||||||
|
|
||||||
|
from gymnasium import utils
|
||||||
|
from gymnasium.envs.mujoco import MujocoEnv
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_HOPPERJUMP = 250
|
MAX_EPISODE_STEPS_HOPPERJUMP = 250
|
||||||
|
|
||||||
|
|
||||||
class HopperJumpEnv(HopperEnv):
|
class HopperEnvCustomXML(HopperEnv):
|
||||||
|
"""
|
||||||
|
Initialization changes to normal Hopper:
|
||||||
|
- terminate_when_unhealthy: True -> False
|
||||||
|
- healthy_reward: 1.0 -> 2.0
|
||||||
|
- healthy_z_range: (0.7, float('inf')) -> (0.5, float('inf'))
|
||||||
|
- healthy_angle_range: (-0.2, 0.2) -> (-float('inf'), float('inf'))
|
||||||
|
- exclude_current_positions_from_observation: True -> False
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
xml_file,
|
||||||
|
forward_reward_weight=1.0,
|
||||||
|
ctrl_cost_weight=1e-3,
|
||||||
|
healthy_reward=1.0,
|
||||||
|
terminate_when_unhealthy=True,
|
||||||
|
healthy_state_range=(-100.0, 100.0),
|
||||||
|
healthy_z_range=(0.7, float("inf")),
|
||||||
|
healthy_angle_range=(-0.2, 0.2),
|
||||||
|
reset_noise_scale=5e-3,
|
||||||
|
exclude_current_positions_from_observation=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
||||||
|
utils.EzPickle.__init__(
|
||||||
|
self,
|
||||||
|
xml_file,
|
||||||
|
forward_reward_weight,
|
||||||
|
ctrl_cost_weight,
|
||||||
|
healthy_reward,
|
||||||
|
terminate_when_unhealthy,
|
||||||
|
healthy_state_range,
|
||||||
|
healthy_z_range,
|
||||||
|
healthy_angle_range,
|
||||||
|
reset_noise_scale,
|
||||||
|
exclude_current_positions_from_observation,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self._forward_reward_weight = forward_reward_weight
|
||||||
|
|
||||||
|
self._ctrl_cost_weight = ctrl_cost_weight
|
||||||
|
|
||||||
|
self._healthy_reward = healthy_reward
|
||||||
|
self._terminate_when_unhealthy = terminate_when_unhealthy
|
||||||
|
|
||||||
|
self._healthy_state_range = healthy_state_range
|
||||||
|
self._healthy_z_range = healthy_z_range
|
||||||
|
self._healthy_angle_range = healthy_angle_range
|
||||||
|
|
||||||
|
self._reset_noise_scale = reset_noise_scale
|
||||||
|
|
||||||
|
self._exclude_current_positions_from_observation = (
|
||||||
|
exclude_current_positions_from_observation
|
||||||
|
)
|
||||||
|
|
||||||
|
if exclude_current_positions_from_observation:
|
||||||
|
observation_space = Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
observation_space = Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(12,), dtype=np.float64
|
||||||
|
)
|
||||||
|
|
||||||
|
MujocoEnv.__init__(
|
||||||
|
self,
|
||||||
|
xml_file,
|
||||||
|
4,
|
||||||
|
observation_space=observation_space,
|
||||||
|
default_camera_config=DEFAULT_CAMERA_CONFIG,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HopperJumpEnv(HopperEnvCustomXML):
|
||||||
"""
|
"""
|
||||||
Initialization changes to normal Hopper:
|
Initialization changes to normal Hopper:
|
||||||
- terminate_when_unhealthy: True -> False
|
- terminate_when_unhealthy: True -> False
|
||||||
|
@ -3,12 +3,12 @@ from typing import Optional, Any, Dict, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.core import ObsType
|
from gymnasium.core import ObsType
|
||||||
from gymnasium.envs.mujoco.hopper_v4 import HopperEnv
|
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_HOPPERTHROW = 250
|
MAX_EPISODE_STEPS_HOPPERTHROW = 250
|
||||||
|
|
||||||
|
|
||||||
class HopperThrowEnv(HopperEnv):
|
class HopperThrowEnv(HopperEnvCustomXML):
|
||||||
"""
|
"""
|
||||||
Initialization changes to normal Hopper:
|
Initialization changes to normal Hopper:
|
||||||
- healthy_reward: 1.0 -> 0.0 -> 0.1
|
- healthy_reward: 1.0 -> 0.0 -> 0.1
|
||||||
@ -104,5 +104,3 @@ class HopperThrowEnv(HopperEnv):
|
|||||||
|
|
||||||
observation = self._get_obs()
|
observation = self._get_obs()
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,13 +2,13 @@ import os
|
|||||||
from typing import Optional, Any, Dict, Tuple
|
from typing import Optional, Any, Dict, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.envs.mujoco.hopper_v4 import HopperEnv
|
from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
|
||||||
from gymnasium.core import ObsType
|
from gymnasium.core import ObsType
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250
|
MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250
|
||||||
|
|
||||||
|
|
||||||
class HopperThrowInBasketEnv(HopperEnv):
|
class HopperThrowInBasketEnv(HopperEnvCustomXML):
|
||||||
"""
|
"""
|
||||||
Initialization changes to normal Hopper:
|
Initialization changes to normal Hopper:
|
||||||
- healthy_reward: 1.0 -> 0.0
|
- healthy_reward: 1.0 -> 0.0
|
||||||
@ -136,6 +136,3 @@ class HopperThrowInBasketEnv(HopperEnv):
|
|||||||
|
|
||||||
observation = self._get_obs()
|
observation = self._get_obs()
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,9 +2,13 @@ import os
|
|||||||
from typing import Optional, Any, Dict, Tuple
|
from typing import Optional, Any, Dict, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.envs.mujoco.walker2d_v4 import Walker2dEnv
|
from gymnasium.envs.mujoco.walker2d_v4 import Walker2dEnv, DEFAULT_CAMERA_CONFIG
|
||||||
from gymnasium.core import ObsType
|
from gymnasium.core import ObsType
|
||||||
|
|
||||||
|
from gymnasium import utils
|
||||||
|
from gymnasium.envs.mujoco import MujocoEnv
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_WALKERJUMP = 300
|
MAX_EPISODE_STEPS_WALKERJUMP = 300
|
||||||
|
|
||||||
|
|
||||||
@ -12,6 +16,67 @@ MAX_EPISODE_STEPS_WALKERJUMP = 300
|
|||||||
# to the same structure as the Hopper, where the angles are randomized (->contexts) and the agent should jump as height
|
# to the same structure as the Hopper, where the angles are randomized (->contexts) and the agent should jump as height
|
||||||
# as possible, while landing at a specific target position
|
# as possible, while landing at a specific target position
|
||||||
|
|
||||||
|
class Walker2dEnvCustomXML(Walker2dEnv):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
xml_file,
|
||||||
|
forward_reward_weight=1.0,
|
||||||
|
ctrl_cost_weight=1e-3,
|
||||||
|
healthy_reward=1.0,
|
||||||
|
terminate_when_unhealthy=True,
|
||||||
|
healthy_z_range=(0.8, 2.0),
|
||||||
|
healthy_angle_range=(-1.0, 1.0),
|
||||||
|
reset_noise_scale=5e-3,
|
||||||
|
exclude_current_positions_from_observation=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
utils.EzPickle.__init__(
|
||||||
|
self,
|
||||||
|
xml_file,
|
||||||
|
forward_reward_weight,
|
||||||
|
ctrl_cost_weight,
|
||||||
|
healthy_reward,
|
||||||
|
terminate_when_unhealthy,
|
||||||
|
healthy_z_range,
|
||||||
|
healthy_angle_range,
|
||||||
|
reset_noise_scale,
|
||||||
|
exclude_current_positions_from_observation,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._forward_reward_weight = forward_reward_weight
|
||||||
|
self._ctrl_cost_weight = ctrl_cost_weight
|
||||||
|
|
||||||
|
self._healthy_reward = healthy_reward
|
||||||
|
self._terminate_when_unhealthy = terminate_when_unhealthy
|
||||||
|
|
||||||
|
self._healthy_z_range = healthy_z_range
|
||||||
|
self._healthy_angle_range = healthy_angle_range
|
||||||
|
|
||||||
|
self._reset_noise_scale = reset_noise_scale
|
||||||
|
|
||||||
|
self._exclude_current_positions_from_observation = (
|
||||||
|
exclude_current_positions_from_observation
|
||||||
|
)
|
||||||
|
|
||||||
|
if exclude_current_positions_from_observation:
|
||||||
|
observation_space = Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
observation_space = Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
|
||||||
|
)
|
||||||
|
|
||||||
|
MujocoEnv.__init__(
|
||||||
|
self,
|
||||||
|
xml_file,
|
||||||
|
4,
|
||||||
|
observation_space=observation_space,
|
||||||
|
default_camera_config=DEFAULT_CAMERA_CONFIG,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Walker2dJumpEnv(Walker2dEnv):
|
class Walker2dJumpEnv(Walker2dEnv):
|
||||||
"""
|
"""
|
||||||
@ -100,4 +165,3 @@ class Walker2dJumpEnv(Walker2dEnv):
|
|||||||
|
|
||||||
observation = self._get_obs()
|
observation = self._get_obs()
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user