Tournament: choose random start side after reset

This commit is contained in:
Mustafa Enes Batur 2023-11-20 13:47:15 +01:00
parent bedb4297f2
commit b95cf85e3d

View File

@ -10,15 +10,14 @@ class AirHockeyTournament(AirHockeyDouble):
When the puck is on one side for more than 15 seconds, the puck is reset and the player gets a penalty.
If a player accumulates 3 penalties, their score is reduced by 1.
"""
def __init__(self, gamma=0.99, horizon=45000, viewer_params={}, agent_name="Agent", opponent_name="Opponent"):
def __init__(self, gamma=0.99, horizon=15000, viewer_params={}, agent_name="Agent", opponent_name="Opponent"):
self.agent_name = agent_name
self.opponent_name = opponent_name
self.score = [0, 0]
self.faults = [0, 0]
self.start_side = np.random.choice([1, -1])
self.start_side = None
self.prev_side = self.start_side
self.timer = 0
def custom_render_callback(viewport, context):
@ -37,6 +36,10 @@ class AirHockeyTournament(AirHockeyDouble):
self.hit_range = np.array([[-0.7, -0.2], [-hit_width, hit_width]]) # Table Frame
def setup(self, obs):
if self.start_side == None:
self.start_side = np.random.choice([1, -1])
self.prev_side = self.start_side
# Initial position of the puck
puck_pos = np.random.rand(2) * (self.hit_range[:, 1] - self.hit_range[:, 0]) + self.hit_range[:, 0]