MuJoCo envs: Follow new spec for render_mode (accept **kwargs in each env constructor and forward them to the parent constructor)

This commit is contained in:
Dominik Moritz Roth 2023-10-23 12:27:13 +02:00
parent b681129a46
commit f7a493d8e5
9 changed files with 40 additions and 26 deletions

View File

@ -101,6 +101,7 @@ class AntJumpEnv(AntEnvCustomXML):
contact_force_range=(-1.0, 1.0), contact_force_range=(-1.0, 1.0),
reset_noise_scale=0.1, reset_noise_scale=0.1,
exclude_current_positions_from_observation=True, exclude_current_positions_from_observation=True,
**kwargs
): ):
self.current_step = 0 self.current_step = 0
self.max_height = 0 self.max_height = 0
@ -113,7 +114,7 @@ class AntJumpEnv(AntEnvCustomXML):
healthy_z_range=healthy_z_range, healthy_z_range=healthy_z_range,
contact_force_range=contact_force_range, contact_force_range=contact_force_range,
reset_noise_scale=reset_noise_scale, reset_noise_scale=reset_noise_scale,
exclude_current_positions_from_observation=exclude_current_positions_from_observation) exclude_current_positions_from_observation=exclude_current_positions_from_observation, **kwargs)
def step(self, action): def step(self, action):
self.current_step += 1 self.current_step += 1

View File

@ -36,7 +36,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
"render_fps": 50 "render_fps": 50
} }
def __init__(self, frame_skip: int = 10, random_init: bool = False): def __init__(self, frame_skip: int = 10, random_init: bool = False, **kwargs):
utils.EzPickle.__init__(**locals()) utils.EzPickle.__init__(**locals())
self._steps = 0 self._steps = 0
self.init_qpos_box_pushing = np.array([0., 0., 0., -1.5, 0., 1.5, 0., 0., 0., 0.6, 0.45, 0.0, 1., 0., 0., 0.]) self.init_qpos_box_pushing = np.array([0., 0., 0., -1.5, 0., 1.5, 0., 0., 0., 0.6, 0.45, 0.0, 1., 0., 0., 0.])
@ -58,7 +58,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
MujocoEnv.__init__(self, MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"), model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"),
frame_skip=self.frame_skip, frame_skip=self.frame_skip,
observation_space=self.observation_space) observation_space=self.observation_space, **kwargs)
self.action_space = spaces.Box(low=-1, high=1, shape=(7,)) self.action_space = spaces.Box(low=-1, high=1, shape=(7,))
def step(self, action): def step(self, action):
@ -305,8 +305,8 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
class BoxPushingDense(BoxPushingEnvBase): class BoxPushingDense(BoxPushingEnvBase):
def __init__(self, frame_skip: int = 10, random_init: bool = False): def __init__(self, **kwargs):
super(BoxPushingDense, self).__init__(frame_skip=frame_skip, random_init=random_init) super(BoxPushingDense, self).__init__(**kwargs)
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
rod_tip_pos, rod_quat, qpos, qvel, action): rod_tip_pos, rod_quat, qpos, qvel, action):
joint_penalty = self._joint_limit_violate_penalty(qpos, joint_penalty = self._joint_limit_violate_penalty(qpos,
@ -329,8 +329,8 @@ class BoxPushingDense(BoxPushingEnvBase):
class BoxPushingTemporalSparse(BoxPushingEnvBase): class BoxPushingTemporalSparse(BoxPushingEnvBase):
def __init__(self, frame_skip: int = 10, random_init: bool = False): def __init__(self, **kwargs):
super(BoxPushingTemporalSparse, self).__init__(frame_skip=frame_skip, random_init=random_init) super(BoxPushingTemporalSparse, self).__init__(**kwargs)
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
rod_tip_pos, rod_quat, qpos, qvel, action): rod_tip_pos, rod_quat, qpos, qvel, action):
@ -361,8 +361,8 @@ class BoxPushingTemporalSparse(BoxPushingEnvBase):
class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase): class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
def __init__(self, frame_skip: int = 10, random_init: bool = False): def __init__(self, **kwargs):
super(BoxPushingTemporalSpatialSparse, self).__init__(frame_skip=frame_skip, random_init=random_init) super(BoxPushingTemporalSpatialSparse, self).__init__(**kwargs)
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
rod_tip_pos, rod_quat, qpos, qvel, action): rod_tip_pos, rod_quat, qpos, qvel, action):
@ -392,8 +392,8 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
class BoxPushingTemporalSpatialSparse2(BoxPushingEnvBase): class BoxPushingTemporalSpatialSparse2(BoxPushingEnvBase):
def __init__(self, frame_skip: int = 10, random_init: bool = False): def __init__(self, **kwargs):
super(BoxPushingTemporalSpatialSparse2, self).__init__(frame_skip=frame_skip, random_init=random_init) super(BoxPushingTemporalSpatialSparse2, self).__init__(**kwargs)
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
rod_tip_pos, rod_quat, qpos, qvel, action): rod_tip_pos, rod_quat, qpos, qvel, action):
@ -428,8 +428,8 @@ class BoxPushingTemporalSpatialSparse2(BoxPushingEnvBase):
class BoxPushingNoConstraintSparse(BoxPushingEnvBase): class BoxPushingNoConstraintSparse(BoxPushingEnvBase):
def __init__(self, frame_skip: int = 10, random_init: bool = False): def __init__(self, **kwargs):
super(BoxPushingNoConstraintSparse, self).__init__(frame_skip=frame_skip, random_init=random_init) super(BoxPushingNoConstraintSparse, self).__init__(**kwargs)
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
rod_tip_pos, rod_quat, qpos, qvel, action): rod_tip_pos, rod_quat, qpos, qvel, action):

View File

@ -74,7 +74,8 @@ class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML):
reset_noise_scale=0.1, reset_noise_scale=0.1,
context=True, context=True,
exclude_current_positions_from_observation=True, exclude_current_positions_from_observation=True,
max_episode_steps=100): max_episode_steps=100,
**kwargs):
self.current_step = 0 self.current_step = 0
self.max_height = 0 self.max_height = 0
# self.max_episode_steps = max_episode_steps # self.max_episode_steps = max_episode_steps
@ -85,7 +86,8 @@ class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML):
forward_reward_weight=forward_reward_weight, forward_reward_weight=forward_reward_weight,
ctrl_cost_weight=ctrl_cost_weight, ctrl_cost_weight=ctrl_cost_weight,
reset_noise_scale=reset_noise_scale, reset_noise_scale=reset_noise_scale,
exclude_current_positions_from_observation=exclude_current_positions_from_observation) exclude_current_positions_from_observation=exclude_current_positions_from_observation,
**kwargs)
def step(self, action): def step(self, action):

View File

@ -115,6 +115,7 @@ class HopperJumpEnv(HopperEnvCustomXML):
reset_noise_scale=5e-3, reset_noise_scale=5e-3,
exclude_current_positions_from_observation=False, exclude_current_positions_from_observation=False,
sparse=False, sparse=False,
**kwargs
): ):
self.sparse = sparse self.sparse = sparse
@ -141,7 +142,8 @@ class HopperJumpEnv(HopperEnvCustomXML):
healthy_z_range=healthy_z_range, healthy_z_range=healthy_z_range,
healthy_angle_range=healthy_angle_range, healthy_angle_range=healthy_angle_range,
reset_noise_scale=reset_noise_scale, reset_noise_scale=reset_noise_scale,
exclude_current_positions_from_observation=exclude_current_positions_from_observation) exclude_current_positions_from_observation=exclude_current_positions_from_observation,
**kwargs)
# increase initial height # increase initial height
self.init_qpos[1] = 1.5 self.init_qpos[1] = 1.5

View File

@ -29,7 +29,8 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
reset_noise_scale=5e-3, reset_noise_scale=5e-3,
context=True, context=True,
exclude_current_positions_from_observation=True, exclude_current_positions_from_observation=True,
max_episode_steps=250): max_episode_steps=250,
**kwargs):
self.current_step = 0 self.current_step = 0
self.max_height = 0 self.max_height = 0
self.max_episode_steps = max_episode_steps self.max_episode_steps = max_episode_steps
@ -50,7 +51,8 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy, super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy,
healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale, healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale,
exclude_current_positions_from_observation) exclude_current_positions_from_observation,
**kwargs)
def step(self, action): def step(self, action):

View File

@ -32,7 +32,8 @@ class HopperThrowEnv(HopperEnvCustomXML):
reset_noise_scale=5e-3, reset_noise_scale=5e-3,
context=True, context=True,
exclude_current_positions_from_observation=True, exclude_current_positions_from_observation=True,
max_episode_steps=250): max_episode_steps=250,
**kwargs):
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
self.current_step = 0 self.current_step = 0
self.max_episode_steps = max_episode_steps self.max_episode_steps = max_episode_steps
@ -57,7 +58,8 @@ class HopperThrowEnv(HopperEnvCustomXML):
healthy_z_range=healthy_z_range, healthy_z_range=healthy_z_range,
healthy_state_range=healthy_angle_range, healthy_state_range=healthy_angle_range,
reset_noise_scale=reset_noise_scale, reset_noise_scale=reset_noise_scale,
exclude_current_positions_from_observation=exclude_current_positions_from_observation) exclude_current_positions_from_observation=exclude_current_positions_from_observation,
**kwargs)
def step(self, action): def step(self, action):
self.current_step += 1 self.current_step += 1

View File

@ -36,7 +36,8 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML):
context=True, context=True,
penalty=0.0, penalty=0.0,
exclude_current_positions_from_observation=True, exclude_current_positions_from_observation=True,
max_episode_steps=250): max_episode_steps=250,
**kwargs):
self.hit_basket_reward = hit_basket_reward self.hit_basket_reward = hit_basket_reward
self.current_step = 0 self.current_step = 0
self.max_episode_steps = max_episode_steps self.max_episode_steps = max_episode_steps
@ -65,7 +66,8 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML):
healthy_z_range=healthy_z_range, healthy_z_range=healthy_z_range,
healthy_angle_range=healthy_angle_range, healthy_angle_range=healthy_angle_range,
reset_noise_scale=reset_noise_scale, reset_noise_scale=reset_noise_scale,
exclude_current_positions_from_observation=exclude_current_positions_from_observation) exclude_current_positions_from_observation=exclude_current_positions_from_observation,
**kwargs)
def step(self, action): def step(self, action):

View File

@ -34,7 +34,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4, def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
goal_switching_step: int = None, goal_switching_step: int = None,
enable_artificial_wind: bool = False): enable_artificial_wind: bool = False, **kwargs):
utils.EzPickle.__init__(**locals()) utils.EzPickle.__init__(**locals())
self._steps = 0 self._steps = 0
@ -68,7 +68,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
MujocoEnv.__init__(self, MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
frame_skip=frame_skip, frame_skip=frame_skip,
observation_space=self.observation_space) observation_space=self.observation_space,
**kwargs)
if ctxt_dim == 2: if ctxt_dim == 2:
self.context_bounds = CONTEXT_BOUNDS_2DIMS self.context_bounds = CONTEXT_BOUNDS_2DIMS

View File

@ -97,7 +97,8 @@ class Walker2dJumpEnv(Walker2dEnvCustomXML):
reset_noise_scale=5e-3, reset_noise_scale=5e-3,
penalty=0, penalty=0,
exclude_current_positions_from_observation=True, exclude_current_positions_from_observation=True,
max_episode_steps=300): max_episode_steps=300,
**kwargs):
self.current_step = 0 self.current_step = 0
self.max_episode_steps = max_episode_steps self.max_episode_steps = max_episode_steps
self.max_height = 0 self.max_height = 0
@ -112,7 +113,8 @@ class Walker2dJumpEnv(Walker2dEnvCustomXML):
healthy_z_range=healthy_z_range, healthy_z_range=healthy_z_range,
healthy_angle_range=healthy_angle_range, healthy_angle_range=healthy_angle_range,
reset_noise_scale=reset_noise_scale, reset_noise_scale=reset_noise_scale,
exclude_current_positions_from_observation=exclude_current_positions_from_observation) exclude_current_positions_from_observation=exclude_current_positions_from_observation,
**kwargs)
def step(self, action): def step(self, action):
self.current_step += 1 self.current_step += 1