diff --git a/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py b/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py index 508c831..3aa197c 100644 --- a/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py +++ b/fancy_gym/envs/mujoco/air_hockey/air_hockey_env_wrapper.py @@ -7,8 +7,6 @@ from fancy_gym.envs.mujoco.air_hockey import position_control_wrapper as positio from fancy_gym.envs.mujoco.air_hockey.utils import robot_to_world from mushroom_rl.core import Environment -MAX_EPISODE_STEPS_AIRHOCKEY = 45000 # For a tournament env, the game can last up to 15 minutes - class AirHockeyEnv(Environment): metadata = {"render_modes": ["human"], "render_fps": 60} @@ -84,6 +82,8 @@ class AirHockeyEnv(Environment): self.env_info['env_name'] = self.env_name self.render_mode = render_mode + self.render_human_active = False + super().__init__(self.base_env.info) def step(self, action): @@ -111,10 +111,14 @@ class AirHockeyEnv(Environment): if self.env_info['env_name'] == "tournament": obs = np.array(np.split(obs, 2)) + + if self.render_human_active: + self.base_env.render() + return obs, reward, done, False, info def render(self): - self.base_env.render() + self.render_human_active = True def reset(self, seed=None, options={}): self.base_env.seed(seed) @@ -146,6 +150,9 @@ class AirHockeyEnv(Environment): @property def unwrapped(self): return self + + def close(self): + return if __name__ == "__main__":