Allow custom XML-files for ant_env

This commit is contained in:
Dominik Moritz Roth 2023-06-10 18:47:41 +02:00
parent ddf6fd73b2
commit 42003a3f9a

View File

@ -2,7 +2,10 @@ from typing import Tuple, Union, Optional, Any, Dict
import numpy as np import numpy as np
from gymnasium.core import ObsType from gymnasium.core import ObsType
from gymnasium.envs.mujoco.ant_v4 import AntEnv from gymnasium.envs.mujoco.ant_v4 import AntEnv, DEFAULT_CAMERA_CONFIG
from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box
MAX_EPISODE_STEPS_ANTJUMP = 200 MAX_EPISODE_STEPS_ANTJUMP = 200
@ -12,8 +15,74 @@ MAX_EPISODE_STEPS_ANTJUMP = 200
# to the same structure as the Hopper, where the angles are randomized (->contexts) and the agent should jump as heigh # to the same structure as the Hopper, where the angles are randomized (->contexts) and the agent should jump as heigh
# as possible, while landing at a specific target position # as possible, while landing at a specific target position
class AntEnvCustomXML(AntEnv):
def __init__(
self,
xml_file="ant.xml",
ctrl_cost_weight=0.5,
use_contact_forces=False,
contact_cost_weight=5e-4,
healthy_reward=1.0,
terminate_when_unhealthy=True,
healthy_z_range=(0.2, 1.0),
contact_force_range=(-1.0, 1.0),
reset_noise_scale=0.1,
exclude_current_positions_from_observation=True,
**kwargs,
):
utils.EzPickle.__init__(
self,
xml_file,
ctrl_cost_weight,
use_contact_forces,
contact_cost_weight,
healthy_reward,
terminate_when_unhealthy,
healthy_z_range,
contact_force_range,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs,
)
class AntJumpEnv(AntEnv): self._ctrl_cost_weight = ctrl_cost_weight
self._contact_cost_weight = contact_cost_weight
self._healthy_reward = healthy_reward
self._terminate_when_unhealthy = terminate_when_unhealthy
self._healthy_z_range = healthy_z_range
self._contact_force_range = contact_force_range
self._reset_noise_scale = reset_noise_scale
self._use_contact_forces = use_contact_forces
self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation
)
obs_shape = 27
if not exclude_current_positions_from_observation:
obs_shape += 2
if use_contact_forces:
obs_shape += 84
observation_space = Box(
low=-np.inf, high=np.inf, shape=(obs_shape,), dtype=np.float64
)
MujocoEnv.__init__(
self,
xml_file,
5,
observation_space=observation_space,
default_camera_config=DEFAULT_CAMERA_CONFIG,
**kwargs,
)
class AntJumpEnv(AntEnvCustomXML):
""" """
Initialization changes to normal Ant: Initialization changes to normal Ant:
- healthy_reward: 1.0 -> 0.01 -> 0.0 no healthy reward needed - Paul and Marc - healthy_reward: 1.0 -> 0.01 -> 0.0 no healthy reward needed - Paul and Marc