diff --git a/fancy_gym/envs/mujoco/air_hockey/seven_dof/tournament.py b/fancy_gym/envs/mujoco/air_hockey/seven_dof/tournament.py index 6220854..49713a0 100644 --- a/fancy_gym/envs/mujoco/air_hockey/seven_dof/tournament.py +++ b/fancy_gym/envs/mujoco/air_hockey/seven_dof/tournament.py @@ -10,15 +10,14 @@ class AirHockeyTournament(AirHockeyDouble): When the puck is on one side for more than 15 seconds the puck is reset and the player gets a penalty. If a player accumulates 3 penalties his score is reduced by 1. """ - def __init__(self, gamma=0.99, horizon=45000, viewer_params={}, agent_name="Agent", opponent_name="Opponent"): + def __init__(self, gamma=0.99, horizon=15000, viewer_params={}, agent_name="Agent", opponent_name="Opponent"): self.agent_name = agent_name self.opponent_name = opponent_name self.score = [0, 0] self.faults = [0, 0] - self.start_side = np.random.choice([1, -1]) + self.start_side = None - self.prev_side = self.start_side self.timer = 0 def custom_render_callback(viewport, context): @@ -37,6 +36,10 @@ class AirHockeyTournament(AirHockeyDouble): self.hit_range = np.array([[-0.7, -0.2], [-hit_width, hit_width]]) # Table Frame def setup(self, obs): + if self.start_side == None: + self.start_side = np.random.choice([1, -1]) + self.prev_side = self.start_side + # Initial position of the puck puck_pos = np.random.rand(2) * (self.hit_range[:, 1] - self.hit_range[:, 0]) + self.hit_range[:, 0]