From 42003a3f9af32788dea8aedb1857bc98f4985ac6 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Sat, 10 Jun 2023 18:47:41 +0200 Subject: [PATCH] Allow custom XML-files for ant_env --- fancy_gym/envs/mujoco/ant_jump/ant_jump.py | 73 +++++++++++++++++++++- 1 file changed, 71 insertions(+), 2 deletions(-) diff --git a/fancy_gym/envs/mujoco/ant_jump/ant_jump.py b/fancy_gym/envs/mujoco/ant_jump/ant_jump.py index fbf0804..b228195 100644 --- a/fancy_gym/envs/mujoco/ant_jump/ant_jump.py +++ b/fancy_gym/envs/mujoco/ant_jump/ant_jump.py @@ -2,7 +2,10 @@ from typing import Tuple, Union, Optional, Any, Dict import numpy as np from gymnasium.core import ObsType -from gymnasium.envs.mujoco.ant_v4 import AntEnv +from gymnasium.envs.mujoco.ant_v4 import AntEnv, DEFAULT_CAMERA_CONFIG +from gymnasium import utils +from gymnasium.envs.mujoco import MujocoEnv +from gymnasium.spaces import Box MAX_EPISODE_STEPS_ANTJUMP = 200 @@ -12,8 +15,74 @@ MAX_EPISODE_STEPS_ANTJUMP = 200 # to the same structure as the Hopper, where the angles are randomized (->contexts) and the agent should jump as heigh # as possible, while landing at a specific target position +class AntEnvCustomXML(AntEnv): + def __init__( + self, + xml_file="ant.xml", + ctrl_cost_weight=0.5, + use_contact_forces=False, + contact_cost_weight=5e-4, + healthy_reward=1.0, + terminate_when_unhealthy=True, + healthy_z_range=(0.2, 1.0), + contact_force_range=(-1.0, 1.0), + reset_noise_scale=0.1, + exclude_current_positions_from_observation=True, + **kwargs, + ): + utils.EzPickle.__init__( + self, + xml_file, + ctrl_cost_weight, + use_contact_forces, + contact_cost_weight, + healthy_reward, + terminate_when_unhealthy, + healthy_z_range, + contact_force_range, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs, + ) -class AntJumpEnv(AntEnv): + self._ctrl_cost_weight = ctrl_cost_weight + self._contact_cost_weight = contact_cost_weight + + self._healthy_reward = healthy_reward + self._terminate_when_unhealthy = terminate_when_unhealthy + self._healthy_z_range = healthy_z_range + + self._contact_force_range = contact_force_range + + self._reset_noise_scale = reset_noise_scale + + self._use_contact_forces = use_contact_forces + + self._exclude_current_positions_from_observation = ( + exclude_current_positions_from_observation + ) + + obs_shape = 27 + if not exclude_current_positions_from_observation: + obs_shape += 2 + if use_contact_forces: + obs_shape += 84 + + observation_space = Box( + low=-np.inf, high=np.inf, shape=(obs_shape,), dtype=np.float64 + ) + + MujocoEnv.__init__( + self, + xml_file, + 5, + observation_space=observation_space, + default_camera_config=DEFAULT_CAMERA_CONFIG, + **kwargs, + ) + + +class AntJumpEnv(AntEnvCustomXML): """ Initialization changes to normal Ant: - healthy_reward: 1.0 -> 0.01 -> 0.0 no healthy reward needed - Paul and Marc