diff --git a/fancy_gym/envs/mujoco/ant_jump/ant_jump.py b/fancy_gym/envs/mujoco/ant_jump/ant_jump.py index ed6bea5..97cde0e 100644 --- a/fancy_gym/envs/mujoco/ant_jump/ant_jump.py +++ b/fancy_gym/envs/mujoco/ant_jump/ant_jump.py @@ -101,6 +101,7 @@ class AntJumpEnv(AntEnvCustomXML): contact_force_range=(-1.0, 1.0), reset_noise_scale=0.1, exclude_current_positions_from_observation=True, + **kwargs ): self.current_step = 0 self.max_height = 0 @@ -113,7 +114,7 @@ class AntJumpEnv(AntEnvCustomXML): healthy_z_range=healthy_z_range, contact_force_range=contact_force_range, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, **kwargs) def step(self, action): self.current_step += 1 diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py index 932e3df..9b512a8 100644 --- a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py +++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py @@ -36,7 +36,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): "render_fps": 50 } - def __init__(self, frame_skip: int = 10, random_init: bool = False): + def __init__(self, frame_skip: int = 10, random_init: bool = False, **kwargs): utils.EzPickle.__init__(**locals()) self._steps = 0 self.init_qpos_box_pushing = np.array([0., 0., 0., -1.5, 0., 1.5, 0., 0., 0., 0.6, 0.45, 0.0, 1., 0., 0., 0.]) @@ -58,7 +58,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): MujocoEnv.__init__(self, model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"), frame_skip=self.frame_skip, - observation_space=self.observation_space) + observation_space=self.observation_space, **kwargs) self.action_space = spaces.Box(low=-1, high=1, shape=(7,)) def step(self, action): @@ -305,8 +305,8 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): class BoxPushingDense(BoxPushingEnvBase): - def __init__(self, frame_skip: int = 10, random_init: bool = False): - super(BoxPushingDense, self).__init__(frame_skip=frame_skip, random_init=random_init) + def __init__(self, **kwargs): + super(BoxPushingDense, self).__init__(**kwargs) def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, rod_tip_pos, rod_quat, qpos, qvel, action): joint_penalty = self._joint_limit_violate_penalty(qpos, @@ -329,8 +329,8 @@ class BoxPushingDense(BoxPushingEnvBase): class BoxPushingTemporalSparse(BoxPushingEnvBase): - def __init__(self, frame_skip: int = 10, random_init: bool = False): - super(BoxPushingTemporalSparse, self).__init__(frame_skip=frame_skip, random_init=random_init) + def __init__(self, **kwargs): + super(BoxPushingTemporalSparse, self).__init__(**kwargs) def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, rod_tip_pos, rod_quat, qpos, qvel, action): @@ -361,8 +361,8 @@ class BoxPushingTemporalSparse(BoxPushingEnvBase): class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase): - def __init__(self, frame_skip: int = 10, random_init: bool = False): - super(BoxPushingTemporalSpatialSparse, self).__init__(frame_skip=frame_skip, random_init=random_init) + def __init__(self, **kwargs): + super(BoxPushingTemporalSpatialSparse, self).__init__(**kwargs) def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, rod_tip_pos, rod_quat, qpos, qvel, action): @@ -392,8 +392,8 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase): class BoxPushingTemporalSpatialSparse2(BoxPushingEnvBase): - def __init__(self, frame_skip: int = 10, random_init: bool = False): - super(BoxPushingTemporalSpatialSparse2, self).__init__(frame_skip=frame_skip, random_init=random_init) + def __init__(self, **kwargs): + super(BoxPushingTemporalSpatialSparse2, self).__init__(**kwargs) def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, rod_tip_pos, rod_quat, qpos, qvel, action): @@ -428,8 +428,8 @@ class BoxPushingTemporalSpatialSparse2(BoxPushingEnvBase): class BoxPushingNoConstraintSparse(BoxPushingEnvBase): - def __init__(self, frame_skip: int = 10, random_init: bool = False): - super(BoxPushingNoConstraintSparse, self).__init__(frame_skip=frame_skip, random_init=random_init) + def __init__(self, **kwargs): + super(BoxPushingNoConstraintSparse, self).__init__(**kwargs) def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat, rod_tip_pos, rod_quat, qpos, qvel, action): diff --git a/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py b/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py index f15a9f4..088f959 100644 --- a/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py +++ b/fancy_gym/envs/mujoco/half_cheetah_jump/half_cheetah_jump.py @@ -74,7 +74,8 @@ class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML): reset_noise_scale=0.1, context=True, exclude_current_positions_from_observation=True, - max_episode_steps=100): + max_episode_steps=100, + **kwargs): self.current_step = 0 self.max_height = 0 # self.max_episode_steps = max_episode_steps @@ -85,7 +86,8 @@ class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML): forward_reward_weight=forward_reward_weight, ctrl_cost_weight=ctrl_cost_weight, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, + **kwargs) def step(self, action): diff --git a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py index b77cab1..ae431ab 100644 --- a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py +++ b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump.py @@ -115,6 +115,7 @@ class HopperJumpEnv(HopperEnvCustomXML): reset_noise_scale=5e-3, exclude_current_positions_from_observation=False, sparse=False, + **kwargs ): self.sparse = sparse @@ -141,7 +142,8 @@ class HopperJumpEnv(HopperEnvCustomXML): healthy_z_range=healthy_z_range, healthy_angle_range=healthy_angle_range, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, + **kwargs) # increase initial height self.init_qpos[1] = 1.5 diff --git a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py index 506344b..c0c57c2 100644 --- a/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py +++ b/fancy_gym/envs/mujoco/hopper_jump/hopper_jump_on_box.py @@ -29,7 +29,8 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML): reset_noise_scale=5e-3, context=True, exclude_current_positions_from_observation=True, - max_episode_steps=250): + max_episode_steps=250, + **kwargs): self.current_step = 0 self.max_height = 0 self.max_episode_steps = max_episode_steps @@ -50,7 +51,8 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML): xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy, healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale, - exclude_current_positions_from_observation) + exclude_current_positions_from_observation, + **kwargs) def step(self, action): diff --git a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py index b5afc8b..7a39cd8 100644 --- a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py +++ b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw.py @@ -32,7 +32,8 @@ class HopperThrowEnv(HopperEnvCustomXML): reset_noise_scale=5e-3, context=True, exclude_current_positions_from_observation=True, - max_episode_steps=250): + max_episode_steps=250, + **kwargs): xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file) self.current_step = 0 self.max_episode_steps = max_episode_steps @@ -57,7 +58,8 @@ class HopperThrowEnv(HopperEnvCustomXML): healthy_z_range=healthy_z_range, healthy_state_range=healthy_angle_range, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, + **kwargs) def step(self, action): self.current_step += 1 diff --git a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py index 00d1bdb..24ad402 100644 --- a/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py +++ b/fancy_gym/envs/mujoco/hopper_throw/hopper_throw_in_basket.py @@ -36,7 +36,8 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML): context=True, penalty=0.0, exclude_current_positions_from_observation=True, - max_episode_steps=250): + max_episode_steps=250, + **kwargs): self.hit_basket_reward = hit_basket_reward self.current_step = 0 self.max_episode_steps = max_episode_steps @@ -65,7 +66,8 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML): healthy_z_range=healthy_z_range, healthy_angle_range=healthy_angle_range, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, + **kwargs) def step(self, action): diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py index 5395de7..6a08ed1 100644 --- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py +++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py @@ -34,7 +34,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4, goal_switching_step: int = None, - enable_artificial_wind: bool = False): + enable_artificial_wind: bool = False, **kwargs): utils.EzPickle.__init__(**locals()) self._steps = 0 @@ -68,7 +68,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle): MujocoEnv.__init__(self, model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"), frame_skip=frame_skip, - observation_space=self.observation_space) + observation_space=self.observation_space, + **kwargs) if ctxt_dim == 2: self.context_bounds = CONTEXT_BOUNDS_2DIMS diff --git a/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py b/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py index 6ad2be0..d9085ee 100644 --- a/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py +++ b/fancy_gym/envs/mujoco/walker_2d_jump/walker_2d_jump.py @@ -97,7 +97,8 @@ class Walker2dJumpEnv(Walker2dEnvCustomXML): reset_noise_scale=5e-3, penalty=0, exclude_current_positions_from_observation=True, - max_episode_steps=300): + max_episode_steps=300, + **kwargs): self.current_step = 0 self.max_episode_steps = max_episode_steps self.max_height = 0 @@ -112,7 +113,8 @@ class Walker2dJumpEnv(Walker2dEnvCustomXML): healthy_z_range=healthy_z_range, healthy_angle_range=healthy_angle_range, reset_noise_scale=reset_noise_scale, - exclude_current_positions_from_observation=exclude_current_positions_from_observation) + exclude_current_positions_from_observation=exclude_current_positions_from_observation, + **kwargs) def step(self, action): self.current_step += 1