Allow custom XML-files for ant_env
This commit is contained in:
parent
ddf6fd73b2
commit
42003a3f9a
@ -2,7 +2,10 @@ from typing import Tuple, Union, Optional, Any, Dict
|
||||
|
||||
import numpy as np
|
||||
from gymnasium.core import ObsType
|
||||
from gymnasium.envs.mujoco.ant_v4 import AntEnv
|
||||
from gymnasium.envs.mujoco.ant_v4 import AntEnv, DEFAULT_CAMERA_CONFIG
|
||||
from gymnasium import utils
|
||||
from gymnasium.envs.mujoco import MujocoEnv
|
||||
from gymnasium.spaces import Box
|
||||
|
||||
MAX_EPISODE_STEPS_ANTJUMP = 200
|
||||
|
||||
@ -12,8 +15,74 @@ MAX_EPISODE_STEPS_ANTJUMP = 200
|
||||
# to the same structure as the Hopper, where the angles are randomized (->contexts) and the agent should jump as heigh
|
||||
# as possible, while landing at a specific target position
|
||||
|
||||
class AntEnvCustomXML(AntEnv):
|
||||
def __init__(
|
||||
self,
|
||||
xml_file="ant.xml",
|
||||
ctrl_cost_weight=0.5,
|
||||
use_contact_forces=False,
|
||||
contact_cost_weight=5e-4,
|
||||
healthy_reward=1.0,
|
||||
terminate_when_unhealthy=True,
|
||||
healthy_z_range=(0.2, 1.0),
|
||||
contact_force_range=(-1.0, 1.0),
|
||||
reset_noise_scale=0.1,
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs,
|
||||
):
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
xml_file,
|
||||
ctrl_cost_weight,
|
||||
use_contact_forces,
|
||||
contact_cost_weight,
|
||||
healthy_reward,
|
||||
terminate_when_unhealthy,
|
||||
healthy_z_range,
|
||||
contact_force_range,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
class AntJumpEnv(AntEnv):
|
||||
self._ctrl_cost_weight = ctrl_cost_weight
|
||||
self._contact_cost_weight = contact_cost_weight
|
||||
|
||||
self._healthy_reward = healthy_reward
|
||||
self._terminate_when_unhealthy = terminate_when_unhealthy
|
||||
self._healthy_z_range = healthy_z_range
|
||||
|
||||
self._contact_force_range = contact_force_range
|
||||
|
||||
self._reset_noise_scale = reset_noise_scale
|
||||
|
||||
self._use_contact_forces = use_contact_forces
|
||||
|
||||
self._exclude_current_positions_from_observation = (
|
||||
exclude_current_positions_from_observation
|
||||
)
|
||||
|
||||
obs_shape = 27
|
||||
if not exclude_current_positions_from_observation:
|
||||
obs_shape += 2
|
||||
if use_contact_forces:
|
||||
obs_shape += 84
|
||||
|
||||
observation_space = Box(
|
||||
low=-np.inf, high=np.inf, shape=(obs_shape,), dtype=np.float64
|
||||
)
|
||||
|
||||
MujocoEnv.__init__(
|
||||
self,
|
||||
xml_file,
|
||||
5,
|
||||
observation_space=observation_space,
|
||||
default_camera_config=DEFAULT_CAMERA_CONFIG,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class AntJumpEnv(AntEnvCustomXML):
|
||||
"""
|
||||
Initialization changes to normal Ant:
|
||||
- healthy_reward: 1.0 -> 0.01 -> 0.0 no healthy reward needed - Paul and Marc
|
||||
|
Loading…
Reference in New Issue
Block a user