Adapted Mujoco Envs to new gymnasium spec

Gymnasium Mujoco Envs no longer allow overriding the used xml_file
We therefore implement intermediate classes, that reimplement this
feature.
This commit is contained in:
Dominik Moritz Roth 2023-05-19 15:18:14 +02:00
parent 1c002a235b
commit dabfc7cafe
5 changed files with 212 additions and 19 deletions

View File

@ -3,12 +3,66 @@ from typing import Tuple, Union, Optional, Any, Dict
import numpy as np import numpy as np
from gymnasium.core import ObsType from gymnasium.core import ObsType
from gymnasium.envs.mujoco.half_cheetah_v4 import HalfCheetahEnv from gymnasium.envs.mujoco.half_cheetah_v4 import HalfCheetahEnv, DEFAULT_CAMERA_CONFIG
from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box
MAX_EPISODE_STEPS_HALFCHEETAHJUMP = 100 MAX_EPISODE_STEPS_HALFCHEETAHJUMP = 100
class HalfCheetahJumpEnv(HalfCheetahEnv): class HalfCheetahEnvCustomXML(HalfCheetahEnv):
def __init__(
self,
xml_file,
forward_reward_weight=1.0,
ctrl_cost_weight=0.1,
reset_noise_scale=0.1,
exclude_current_positions_from_observation=True,
**kwargs,
):
utils.EzPickle.__init__(
self,
xml_file,
forward_reward_weight,
ctrl_cost_weight,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs,
)
self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
MujocoEnv.__init__(
self,
xml_file,
5,
observation_space=observation_space,
default_camera_config=DEFAULT_CAMERA_CONFIG,
**kwargs,
)
class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML):
""" """
_ctrl_cost_weight 0.1 -> 0.0 _ctrl_cost_weight 0.1 -> 0.0
""" """
@ -41,7 +95,7 @@ class HalfCheetahJumpEnv(HalfCheetahEnv):
height_after = self.get_body_com("torso")[2] height_after = self.get_body_com("torso")[2]
self.max_height = max(height_after, self.max_height) self.max_height = max(height_after, self.max_height)
## Didnt use fell_over, because base env also has no done condition - Paul and Marc # Didnt use fell_over, because base env also has no done condition - Paul and Marc
# fell_over = abs(self.sim.data.qpos[2]) > 2.5 # how to figure out if the cheetah fell over? -> 2.5 oke? # fell_over = abs(self.sim.data.qpos[2]) > 2.5 # how to figure out if the cheetah fell over? -> 2.5 oke?
# TODO: Should a fall over be checked here? # TODO: Should a fall over be checked here?
terminated = False terminated = False

View File

@ -1,12 +1,92 @@
import os import os
import numpy as np import numpy as np
from gymnasium.envs.mujoco.hopper_v4 import HopperEnv from gymnasium.envs.mujoco.hopper_v4 import HopperEnv, DEFAULT_CAMERA_CONFIG
from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box
MAX_EPISODE_STEPS_HOPPERJUMP = 250 MAX_EPISODE_STEPS_HOPPERJUMP = 250
class HopperJumpEnv(HopperEnv): class HopperEnvCustomXML(HopperEnv):
"""
Initialization changes to normal Hopper:
- terminate_when_unhealthy: True -> False
- healthy_reward: 1.0 -> 2.0
- healthy_z_range: (0.7, float('inf')) -> (0.5, float('inf'))
- healthy_angle_range: (-0.2, 0.2) -> (-float('inf'), float('inf'))
- exclude_current_positions_from_observation: True -> False
"""
def __init__(
self,
xml_file,
forward_reward_weight=1.0,
ctrl_cost_weight=1e-3,
healthy_reward=1.0,
terminate_when_unhealthy=True,
healthy_state_range=(-100.0, 100.0),
healthy_z_range=(0.7, float("inf")),
healthy_angle_range=(-0.2, 0.2),
reset_noise_scale=5e-3,
exclude_current_positions_from_observation=True,
**kwargs,
):
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
utils.EzPickle.__init__(
self,
xml_file,
forward_reward_weight,
ctrl_cost_weight,
healthy_reward,
terminate_when_unhealthy,
healthy_state_range,
healthy_z_range,
healthy_angle_range,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight
self._healthy_reward = healthy_reward
self._terminate_when_unhealthy = terminate_when_unhealthy
self._healthy_state_range = healthy_state_range
self._healthy_z_range = healthy_z_range
self._healthy_angle_range = healthy_angle_range
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(12,), dtype=np.float64
)
MujocoEnv.__init__(
self,
xml_file,
4,
observation_space=observation_space,
default_camera_config=DEFAULT_CAMERA_CONFIG,
**kwargs,
)
class HopperJumpEnv(HopperEnvCustomXML):
""" """
Initialization changes to normal Hopper: Initialization changes to normal Hopper:
- terminate_when_unhealthy: True -> False - terminate_when_unhealthy: True -> False
@ -141,8 +221,8 @@ class HopperJumpEnv(HopperEnv):
noise_high[5] = 0.785 noise_high[5] = 0.785
qpos = ( qpos = (
self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq) + self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nq) +
self.init_qpos self.init_qpos
) )
qvel = ( qvel = (
# self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv) + # self.np_random.uniform(low=noise_low, high=noise_high, size=self.model.nv) +

View File

@ -3,12 +3,12 @@ from typing import Optional, Any, Dict, Tuple
import numpy as np import numpy as np
from gymnasium.core import ObsType from gymnasium.core import ObsType
from gymnasium.envs.mujoco.hopper_v4 import HopperEnv from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
MAX_EPISODE_STEPS_HOPPERTHROW = 250 MAX_EPISODE_STEPS_HOPPERTHROW = 250
class HopperThrowEnv(HopperEnv): class HopperThrowEnv(HopperEnvCustomXML):
""" """
Initialization changes to normal Hopper: Initialization changes to normal Hopper:
- healthy_reward: 1.0 -> 0.0 -> 0.1 - healthy_reward: 1.0 -> 0.0 -> 0.1
@ -104,5 +104,3 @@ class HopperThrowEnv(HopperEnv):
observation = self._get_obs() observation = self._get_obs()
return observation return observation

View File

@ -2,13 +2,13 @@ import os
from typing import Optional, Any, Dict, Tuple from typing import Optional, Any, Dict, Tuple
import numpy as np import numpy as np
from gymnasium.envs.mujoco.hopper_v4 import HopperEnv from fancy_gym.envs.mujoco.hopper_jump.hopper_jump import HopperEnvCustomXML
from gymnasium.core import ObsType from gymnasium.core import ObsType
MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250 MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250
class HopperThrowInBasketEnv(HopperEnv): class HopperThrowInBasketEnv(HopperEnvCustomXML):
""" """
Initialization changes to normal Hopper: Initialization changes to normal Hopper:
- healthy_reward: 1.0 -> 0.0 - healthy_reward: 1.0 -> 0.0
@ -66,7 +66,7 @@ class HopperThrowInBasketEnv(HopperEnv):
is_in_basket_x = ball_pos[0] >= basket_pos[0] and ball_pos[0] <= basket_pos[0] + self.basket_size is_in_basket_x = ball_pos[0] >= basket_pos[0] and ball_pos[0] <= basket_pos[0] + self.basket_size
is_in_basket_y = ball_pos[1] >= basket_pos[1] - (self.basket_size / 2) and ball_pos[1] <= basket_pos[1] + ( is_in_basket_y = ball_pos[1] >= basket_pos[1] - (self.basket_size / 2) and ball_pos[1] <= basket_pos[1] + (
self.basket_size / 2) self.basket_size / 2)
is_in_basket_z = ball_pos[2] < 0.1 is_in_basket_z = ball_pos[2] < 0.1
is_in_basket = is_in_basket_x and is_in_basket_y and is_in_basket_z is_in_basket = is_in_basket_x and is_in_basket_y and is_in_basket_z
if is_in_basket: if is_in_basket:
@ -136,6 +136,3 @@ class HopperThrowInBasketEnv(HopperEnv):
observation = self._get_obs() observation = self._get_obs()
return observation return observation

View File

@ -2,9 +2,13 @@ import os
from typing import Optional, Any, Dict, Tuple from typing import Optional, Any, Dict, Tuple
import numpy as np import numpy as np
from gymnasium.envs.mujoco.walker2d_v4 import Walker2dEnv from gymnasium.envs.mujoco.walker2d_v4 import Walker2dEnv, DEFAULT_CAMERA_CONFIG
from gymnasium.core import ObsType from gymnasium.core import ObsType
from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box
MAX_EPISODE_STEPS_WALKERJUMP = 300 MAX_EPISODE_STEPS_WALKERJUMP = 300
@ -12,6 +16,67 @@ MAX_EPISODE_STEPS_WALKERJUMP = 300
# to the same structure as the Hopper, where the angles are randomized (->contexts) and the agent should jump as height # to the same structure as the Hopper, where the angles are randomized (->contexts) and the agent should jump as height
# as possible, while landing at a specific target position # as possible, while landing at a specific target position
class Walker2dEnvCustomXML(Walker2dEnv):
def __init__(
self,
xml_file,
forward_reward_weight=1.0,
ctrl_cost_weight=1e-3,
healthy_reward=1.0,
terminate_when_unhealthy=True,
healthy_z_range=(0.8, 2.0),
healthy_angle_range=(-1.0, 1.0),
reset_noise_scale=5e-3,
exclude_current_positions_from_observation=True,
**kwargs,
):
utils.EzPickle.__init__(
self,
xml_file,
forward_reward_weight,
ctrl_cost_weight,
healthy_reward,
terminate_when_unhealthy,
healthy_z_range,
healthy_angle_range,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs,
)
self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight
self._healthy_reward = healthy_reward
self._terminate_when_unhealthy = terminate_when_unhealthy
self._healthy_z_range = healthy_z_range
self._healthy_angle_range = healthy_angle_range
self._reset_noise_scale = reset_noise_scale
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
MujocoEnv.__init__(
self,
xml_file,
4,
observation_space=observation_space,
default_camera_config=DEFAULT_CAMERA_CONFIG,
**kwargs,
)
class Walker2dJumpEnv(Walker2dEnv): class Walker2dJumpEnv(Walker2dEnv):
""" """
@ -100,4 +165,3 @@ class Walker2dJumpEnv(Walker2dEnv):
observation = self._get_obs() observation = self._get_obs()
return observation return observation