Allow custom XML-files for ant_env
This commit is contained in:
parent
ddf6fd73b2
commit
42003a3f9a
@ -2,7 +2,10 @@ from typing import Tuple, Union, Optional, Any, Dict
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.core import ObsType
|
from gymnasium.core import ObsType
|
||||||
from gymnasium.envs.mujoco.ant_v4 import AntEnv
|
from gymnasium.envs.mujoco.ant_v4 import AntEnv, DEFAULT_CAMERA_CONFIG
|
||||||
|
from gymnasium import utils
|
||||||
|
from gymnasium.envs.mujoco import MujocoEnv
|
||||||
|
from gymnasium.spaces import Box
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_ANTJUMP = 200
|
MAX_EPISODE_STEPS_ANTJUMP = 200
|
||||||
|
|
||||||
@ -12,8 +15,74 @@ MAX_EPISODE_STEPS_ANTJUMP = 200
|
|||||||
# to the same structure as the Hopper, where the angles are randomized (->contexts) and the agent should jump as heigh
|
# to the same structure as the Hopper, where the angles are randomized (->contexts) and the agent should jump as heigh
|
||||||
# as possible, while landing at a specific target position
|
# as possible, while landing at a specific target position
|
||||||
|
|
||||||
|
class AntEnvCustomXML(AntEnv):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
xml_file="ant.xml",
|
||||||
|
ctrl_cost_weight=0.5,
|
||||||
|
use_contact_forces=False,
|
||||||
|
contact_cost_weight=5e-4,
|
||||||
|
healthy_reward=1.0,
|
||||||
|
terminate_when_unhealthy=True,
|
||||||
|
healthy_z_range=(0.2, 1.0),
|
||||||
|
contact_force_range=(-1.0, 1.0),
|
||||||
|
reset_noise_scale=0.1,
|
||||||
|
exclude_current_positions_from_observation=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
utils.EzPickle.__init__(
|
||||||
|
self,
|
||||||
|
xml_file,
|
||||||
|
ctrl_cost_weight,
|
||||||
|
use_contact_forces,
|
||||||
|
contact_cost_weight,
|
||||||
|
healthy_reward,
|
||||||
|
terminate_when_unhealthy,
|
||||||
|
healthy_z_range,
|
||||||
|
contact_force_range,
|
||||||
|
reset_noise_scale,
|
||||||
|
exclude_current_positions_from_observation,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
class AntJumpEnv(AntEnv):
|
self._ctrl_cost_weight = ctrl_cost_weight
|
||||||
|
self._contact_cost_weight = contact_cost_weight
|
||||||
|
|
||||||
|
self._healthy_reward = healthy_reward
|
||||||
|
self._terminate_when_unhealthy = terminate_when_unhealthy
|
||||||
|
self._healthy_z_range = healthy_z_range
|
||||||
|
|
||||||
|
self._contact_force_range = contact_force_range
|
||||||
|
|
||||||
|
self._reset_noise_scale = reset_noise_scale
|
||||||
|
|
||||||
|
self._use_contact_forces = use_contact_forces
|
||||||
|
|
||||||
|
self._exclude_current_positions_from_observation = (
|
||||||
|
exclude_current_positions_from_observation
|
||||||
|
)
|
||||||
|
|
||||||
|
obs_shape = 27
|
||||||
|
if not exclude_current_positions_from_observation:
|
||||||
|
obs_shape += 2
|
||||||
|
if use_contact_forces:
|
||||||
|
obs_shape += 84
|
||||||
|
|
||||||
|
observation_space = Box(
|
||||||
|
low=-np.inf, high=np.inf, shape=(obs_shape,), dtype=np.float64
|
||||||
|
)
|
||||||
|
|
||||||
|
MujocoEnv.__init__(
|
||||||
|
self,
|
||||||
|
xml_file,
|
||||||
|
5,
|
||||||
|
observation_space=observation_space,
|
||||||
|
default_camera_config=DEFAULT_CAMERA_CONFIG,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AntJumpEnv(AntEnvCustomXML):
|
||||||
"""
|
"""
|
||||||
Initialization changes to normal Ant:
|
Initialization changes to normal Ant:
|
||||||
- healthy_reward: 1.0 -> 0.01 -> 0.0 no healthy reward needed - Paul and Marc
|
- healthy_reward: 1.0 -> 0.01 -> 0.0 no healthy reward needed - Paul and Marc
|
||||||
|
Loading…
Reference in New Issue
Block a user